# -*- coding: utf-8 -*-
"""
Trainer Class
"""
from __future__ import annotations
import asyncio
import traceback
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
import pandas as pd
import ray
from trinity.algorithm import SAMPLE_STRATEGY
from trinity.common.config import Config
from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle
from trinity.common.experience import Experiences
from trinity.manager.state_manager import StateManager
from trinity.manager.synchronizer import Synchronizer
from trinity.utils.log import get_logger
from trinity.utils.monitor import MONITOR
from trinity.utils.plugin_loader import load_plugins
from trinity.utils.timer import Timer
class Trainer:
    """Consume the experience and train the model."""

    def __init__(self, config: Config) -> None:
        """Build the trainer and all of its collaborators from ``config``.

        Args:
            config: Global configuration. Note that this constructor writes the
                restored experience index back into
                ``config.buffer.trainer_input.experience_buffer.index``.
        """
        self.config = config
        self.logger = get_logger(config.trainer.name, in_ray_actor=True)
        load_plugins()
        self.synchronizer = Synchronizer.get_actor(config)
        self.engine = get_trainer_wrapper(config)
        self.state = StateManager(
            path=config.checkpoint_job_dir, trainer_name=config.trainer.name, config=config
        )
        # Restore the buffer read position from the persisted trainer state;
        # falls back to 0 when the key is absent.
        # NOTE(review): `save_checkpoint` passes `current_exp_index` while this
        # reads `latest_exp_index` -- presumably StateManager maps one onto the
        # other; confirm.
        trainer_state = self.state.load_trainer()
        config.buffer.trainer_input.experience_buffer.index = trainer_state.get(
            "latest_exp_index", 0
        )
        # Step of the last successful NCCL sync; used by the DYNAMIC_BY_TRAINER
        # branch of `need_sync`.
        self.last_trainer_sync_step = 0
        self.monitor = MONITOR.get(config.monitor.monitor_type)(
            project=config.project,
            group=self.config.group,
            name=config.name,
            role=config.trainer.name,
            config=config,
        )
        # Representative samples accumulated until the next table flush.
        self._sample_exps_to_log = []
        self.sample_strategy = SAMPLE_STRATEGY.get(config.algorithm.sample_strategy)(
            buffer_config=config.buffer,
            **config.algorithm.sample_strategy_args,
        )
        self.save_interval = config.trainer.save_interval
        # Step at which `sync_weight` last ran (None until the first call).
        self.last_sync_step = None
        # A falsy `total_steps` (None / 0) means "train until the data runs out".
        self.total_steps = config.trainer.total_steps or float("inf")

    async def prepare(self) -> None:
        """Prepare the engine and mark the trainer as running."""
        self.engine.prepare()
        # The engine may report a non-zero step after `prepare()`; start the
        # trainer-side sync bookkeeping from whatever it reports.
        self.last_trainer_sync_step = self.train_step_num
        await self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)

    async def train(self) -> str:
        """Run the training loop until ``total_steps`` is reached or data runs out.

        Each iteration samples a batch, trains on it, and then optionally
        synchronizes weights, saves a checkpoint and logs preview samples.
        Any exception stops the loop; it is logged rather than re-raised.
        A final blocking checkpoint is always saved afterwards.

        Returns:
            str: The trainer name from the configuration.
        """
        while self.train_step_num < self.total_steps:
            sample_task = None
            try:
                # sample may be blocked due to explorer does not generate enough data
                self.logger.info(f"Sample data for step {self.train_step_num + 1} started.")
                sample_task = asyncio.create_task(self._sample_data())
                while not sample_task.done():
                    # sync weight to make sure the explorer can continue to explore
                    # and generate enough data
                    if await self.need_sync():
                        # Currently, we do not record the metrics of sync_weight here
                        await self.sync_weight()
                    # Wait at most one second, but wake up immediately once the
                    # sampling task completes instead of always sleeping the
                    # full interval. `asyncio.wait` neither cancels the task nor
                    # raises on timeout.
                    await asyncio.wait({sample_task}, timeout=1)
                exps, metrics, repr_samples = await sample_task
                self.logger.info(f"Sample data for step {self.train_step_num + 1} finished.")
                metrics.update(await self.train_step(exps))
                if await self.need_sync():
                    metrics.update(await self.sync_weight())
                if self.need_save():
                    metrics.update(self.save_checkpoint())
                if self.config.trainer.enable_preview:
                    self._log_experiences(repr_samples)
                self.monitor.log(metrics, self.train_step_num)
            except StopAsyncIteration:
                self.logger.info("No more samples to train. Stopping training.")
                break
            except Exception:
                self.logger.error(f"Error in Trainer:\n{traceback.format_exc()}")
                break
            finally:
                # If polling failed while sampling was still in flight, do not
                # leave the sampling task running in the background.
                if sample_task is not None and not sample_task.done():
                    sample_task.cancel()

        self.save_checkpoint(block_until_saved=True, save_as_hf=True)
        await self.synchronizer.set_trainer_status.remote(RunningStatus.STOPPED)
        self.logger.info("--------------------\n> Trainer finished.\n--------------------")
        return self.config.trainer.name

    async def train_step(self, exps: Experiences) -> Dict:
        """Train one step.

        Args:
            exps: The batch of experiences to train on.

        Returns:
            Dict: Metrics of the training step, including ``time/train_step``.
        """
        self.logger.info(f"Training at step {self.train_step_num + 1} started.")
        metrics = {}
        with Timer(metrics, "time/train_step"):
            train_metrics = self.engine.train_step(exps)
        self.logger.info(f"Training at step {self.train_step_num} finished.")
        # Engine metrics are merged last, so they win on key collisions.
        metrics.update(train_metrics)
        return metrics

    async def _sample_data(self) -> Tuple[Experiences, Dict, List[Dict]]:
        """Sample a batch of experiences.

        Returns:
            Experiences: A batch of experiences.
            Dict: Metrics of the sampling step, including ``time/sample_data``.
            List[Dict]: A list of representative samples for logging.
        """
        # Time into a local dict first: the strategy's own metrics dict only
        # exists once `sample` has returned.
        timing = {}
        with Timer(timing, "time/sample_data"):
            batch, metrics, repr_samples = await self.sample_strategy.sample(
                self.train_step_num + 1
            )
        # Report the measured time instead of discarding it; never overwrite a
        # value the strategy already reported under the same key.
        for key, value in timing.items():
            metrics.setdefault(key, value)
        return batch, metrics, repr_samples

    async def need_sync(self) -> bool:
        """Whether to sync the model weight.

        * ``SyncStyle.FIXED``: true when the current step is a multiple of
          ``sync_interval`` and no sync has been recorded for this step yet.
        * Otherwise: for ``SyncStyle.DYNAMIC_BY_TRAINER`` the trainer status is
          first set to ``REQUIRE_SYNC`` once ``sync_interval`` steps have passed
          since the last trainer-side sync. The result is then derived from the
          explorer status counts held by the synchronizer.
        """
        sync_config = self.config.synchronizer
        if sync_config.sync_style == SyncStyle.FIXED:
            return (
                self.last_sync_step != self.train_step_num
                and self.train_step_num % sync_config.sync_interval == 0
            )

        if sync_config.sync_style == SyncStyle.DYNAMIC_BY_TRAINER:
            delta = self.train_step_num - self.last_trainer_sync_step
            if delta >= sync_config.sync_interval:
                await self.synchronizer.set_trainer_status.remote(RunningStatus.REQUIRE_SYNC)
        explorer_status_counts = await self.synchronizer.get_explorer_status_counts.remote()
        if sync_config.sync_method == SyncMethod.NCCL:
            return explorer_status_counts[RunningStatus.WAITING_SYNC] > 0
        # memory & checkpoint
        return explorer_status_counts[RunningStatus.REQUIRE_SYNC] > 0

    def need_save(self) -> bool:
        """Whether to save the checkpoint.

        True when ``save_interval`` is positive and the current step is a
        multiple of it.
        """
        return self.save_interval > 0 and self.train_step_num % self.save_interval == 0

    async def sync_weight(self) -> Dict:
        """Sync the model weight using the configured ``sync_method``.

        Returns:
            Dict: Metrics of the synchronization (``time/sync_weight``).
        """
        self.logger.info(f"Trainer synchronizing weights at step {self.train_step_num} starting..")
        metrics = {}
        with Timer(metrics, "time/sync_weight"):
            sync_method = self.config.synchronizer.sync_method
            if sync_method == SyncMethod.NCCL:
                result = await self.synchronizer.ready_to_nccl_sync.remote(
                    "trainer", self.train_step_num
                )
                if result is None:
                    self.logger.error("Trainer synchronizing weights failed.")
                else:
                    self.engine.sync_weight()
                    # Only a successful NCCL sync advances the trainer-side marker.
                    self.last_trainer_sync_step = self.train_step_num
            elif sync_method == SyncMethod.CHECKPOINT:
                self.engine.save_state_dict()
            elif sync_method == SyncMethod.MEMORY:
                self.engine.upload_state_dict()
            # Recorded even when the NCCL handshake failed, so FIXED-style
            # `need_sync` does not request another sync at this same step.
            self.last_sync_step = self.train_step_num
            await self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)
        self.logger.info(f"Trainer synchronizing weights at step {self.train_step_num} end.")
        return metrics

    def _log_experiences(self, samples: List[Dict]) -> None:
        """Buffer representative samples and flush them as a table periodically.

        The buffered samples are logged (and the buffer cleared) whenever the
        current step is a multiple of ``synchronizer.sync_interval``.
        """
        self._sample_exps_to_log.extend(samples)
        if self.train_step_num % self.config.synchronizer.sync_interval == 0:
            self.monitor.log_table(
                "rollout_examples", pd.DataFrame(self._sample_exps_to_log), self.train_step_num
            )
            self._sample_exps_to_log.clear()

    def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = False) -> Dict:
        """Save an engine checkpoint together with the trainer state.

        Args:
            block_until_saved: Forwarded unchanged to the engine.
            save_as_hf: Forwarded unchanged to the engine.

        Returns:
            Dict: Metrics of the save (``time/save_checkpoint``).
        """
        metrics = {}
        with Timer(metrics, "time/save_checkpoint"):
            self.logger.info(f"Saving checkpoint at step {self.train_step_num}...")
            self.engine.save_checkpoint(block_until_saved=block_until_saved, save_as_hf=save_as_hf)
            # The experience index is derived as (steps trained) * (batch size).
            self.state.save_trainer(
                current_exp_index=self.engine.train_step_num * self.config.buffer.train_batch_size,
                current_step=self.train_step_num,
            )
            self.logger.info(f"Checkpoint at step {self.train_step_num} saved.")
        return metrics

    async def shutdown(self) -> None:
        """Close the monitor."""
        self.monitor.close()

    @property
    def train_step_num(self) -> int:
        """Get the current training step number, as reported by the engine."""
        return self.engine.train_step_num

    async def is_alive(self) -> bool:
        """Check if the trainer is alive."""
        return True

    @classmethod
    def get_actor(cls, config: Config):
        """Get a Ray actor for the trainer.

        The actor is named ``config.trainer.name`` and is created in the
        namespace of the current Ray runtime context.
        """
        return (
            ray.remote(cls)
            .options(name=config.trainer.name, namespace=ray.get_runtime_context().namespace)
            .remote(config)
        )
class TrainEngineWrapper(ABC):
    """A wrapper class to wrap various training engines.

    Every member below is abstract, so a concrete engine must implement all
    of them before it can be instantiated.
    """

    @abstractmethod
    def prepare(self) -> None:
        """Do any preparation required before training starts."""

    @property
    @abstractmethod
    def train_step_num(self) -> int:
        """Get the current training step number."""

    @abstractmethod
    def train_step(self, batch: Experiences) -> Dict:
        """Train one step.

        Args:
            batch (Experiences): A batch of experiences to train.

        Returns:
            Dict: Metrics of the training step.
        """

    @abstractmethod
    def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = False) -> None:
        """Save the checkpoint."""

    @abstractmethod
    def sync_weight(self) -> None:
        """Sync the model weight."""

    @abstractmethod
    def upload_state_dict(self) -> None:
        """Upload the state dict to Synchronizer."""

    @abstractmethod
    def save_state_dict(self) -> None:
        """Only save the model state dict for Synchronizer."""
def get_trainer_wrapper(config: Config) -> TrainEngineWrapper:
    """Get a trainer wrapper.

    Args:
        config: Global configuration; ``config.trainer.trainer_type`` selects
            the engine. Only ``"verl"`` is handled here.

    Returns:
        TrainEngineWrapper: The engine wrapper built from ``config``.

    Raises:
        NotImplementedError: If ``config.trainer.trainer_type`` is not supported.
    """
    trainer_type = config.trainer.trainer_type
    if trainer_type == "verl":
        # Imported inside the branch so the verl wrapper is only loaded when it
        # is actually selected.
        from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper

        return VerlPPOTrainerWrapper(config)
    # Name the rejected value so a misconfiguration is easy to spot.
    raise NotImplementedError(
        f"Unsupported trainer type: {trainer_type!r}. Supported trainer types: 'verl'."
    )