# -*- coding: utf-8 -*-
"""
Trainer Class
"""
from __future__ import annotations
import asyncio
import traceback
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
import pandas as pd
import ray
from trinity.algorithm import SAMPLE_STRATEGY
from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import Config
from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle
from trinity.common.experience import Experiences
from trinity.manager.synchronizer import Synchronizer
from trinity.utils.log import get_logger
from trinity.utils.monitor import MONITOR
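# Typical lifecycle, as driven by an external launcher (a minimal sketch; how
# `config` is constructed is outside this module):
#
#     config = Config(...)            # populated from a YAML/CLI entrypoint
#     trainer = Trainer(config)
#     trainer.prepare()
#     asyncio.run(trainer.train())    # runs until the experience buffer drains
#     asyncio.run(trainer.shutdown())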
class Trainer:
"""Consume the experience and train the model."""
def __init__(self, config: Config) -> None:
self.config = config
self.logger = get_logger(__name__)
self.synchronizer = Synchronizer.get_actor(config)
self.engine = get_trainer_wrapper(config)
self.last_trainer_sync_step = 0
self.monitor = MONITOR.get(config.monitor.monitor_type)(
project=config.project,
            group=config.group,
name=config.name,
role=config.trainer.name,
config=config,
)
        self._sample_exps_to_log: List[Dict] = []
self.sample_strategy = SAMPLE_STRATEGY.get(config.algorithm.sample_strategy)(
buffer_config=config.buffer,
**config.algorithm.sample_strategy_args,
)
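        # SAMPLE_STRATEGY (like MONITOR above) is a plugin registry: ``get``
        # returns a class, which is then instantiated with config-driven kwargs.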
self.train_continue = True
self.last_sync_step = None
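        # `last_sync_step` gates the FIXED sync style in `need_sync`, while
        # `last_trainer_sync_step` (advanced after a successful NCCL sync)
        # drives the DYNAMIC_BY_TRAINER interval check.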
def prepare(self) -> None:
"""Prepare the trainer."""
self.engine.prepare()
self.last_trainer_sync_step = self.train_step_num
ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING))
async def train(self) -> str:
"""Train the model."""
while self.train_continue:
try:
train_task = asyncio.create_task(self.train_step())
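                # Poll once per second while the step runs so a pending weight
                # sync (e.g. an NCCL rendezvous requested by the explorers) is
                # serviced without waiting for the whole step to finish.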
while not train_task.done():
if self.need_sync():
self.sync_weight()
await asyncio.sleep(1)
self.train_continue &= await train_task
if self.train_continue and self.need_sync():
self.sync_weight()
except Exception:
self.logger.error(f"Error in Trainer:\n{traceback.format_exc()}")
self.train_continue = False
self.engine.save_checkpoint(block_until_saved=True)
await self.synchronizer.set_trainer_status.remote(RunningStatus.STOPPED)
self.logger.info("--------------------\n> Trainer finished.\n--------------------")
return self.config.trainer.name
async def train_step(self) -> bool:
"""Train one step.
Returns:
bool: Whether to continue training.
"""
self.logger.info(f"Training at step {self.train_step_num + 1} started.")
try:
batch, sample_metrics, repr_samples = await self.sample_strategy.sample(
self.train_step_num + 1
)
except StopAsyncIteration:
self.logger.info("No more samples to train. Stopping training.")
if (
self.config.trainer.save_interval == 0
or self.train_step_num % self.config.trainer.save_interval != 0
):
self.logger.info(f"Saving at step {self.train_step_num}.")
self.engine.save_checkpoint()
self.logger.info(f"Saved at step {self.train_step_num}.")
return False
self.logger.info(f"Sampling at step {self.train_step_num + 1} done.")
continue_run, metrics = self.engine.train_step(batch)
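        # `engine.train_step` advances the engine's step counter, so the log
        # below already refers to the step that just finished.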
self.logger.info(f"Training at step {self.train_step_num} finished.")
prefix_metrics(sample_metrics, "sample", metrics)
self.monitor.log(data=metrics, step=self.train_step_num)
if self.config.trainer.enable_preview:
self._log_experiences(repr_samples)
return continue_run
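    # Decision summary for `need_sync` (restating the branches below):
    #   FIXED:               sync every `sync_interval` trainer steps.
    #   DYNAMIC_BY_TRAINER:  after `sync_interval` steps the trainer flags
    #                        itself REQUIRE_SYNC, then waits on the explorers.
    #   NCCL sync method:    needs a two-sided rendezvous, so wait until an
    #                        explorer is WAITING_SYNC.
    #   memory/checkpoint:   one-sided, so any explorer in REQUIRE_SYNC is
    #                        enough.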
def need_sync(self) -> bool:
"""Whether to sync the model weight."""
if self.config.synchronizer.sync_style == SyncStyle.FIXED:
return (
self.last_sync_step != self.train_step_num
and self.train_step_num % self.config.synchronizer.sync_interval == 0
)
else:
if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_TRAINER:
delta = self.train_step_num - self.last_trainer_sync_step
if delta >= self.config.synchronizer.sync_interval:
ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.REQUIRE_SYNC))
explorer_status_counts = ray.get(self.synchronizer.get_explorer_status_counts.remote())
if self.config.synchronizer.sync_method == SyncMethod.NCCL:
return explorer_status_counts[RunningStatus.WAITING_SYNC] > 0
else: # memory & checkpoint
return explorer_status_counts[RunningStatus.REQUIRE_SYNC] > 0
def sync_weight(self) -> None:
"""Sync the model weight."""
self.logger.info(f"Trainer synchronizing weights at step {self.train_step_num} starting..")
if self.config.synchronizer.sync_method == SyncMethod.NCCL:
result = ray.get(
self.synchronizer.ready_to_nccl_sync.remote("trainer", self.train_step_num)
)
if result is None:
self.logger.error("Trainer synchronizing weights failed.")
else:
self.engine.sync_weight()
self.last_trainer_sync_step = self.train_step_num
elif self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT:
self.engine.save_state_dict()
elif self.config.synchronizer.sync_method == SyncMethod.MEMORY:
self.engine.upload_state_dict()
self.logger.info(f"Trainer synchronizing weights at step {self.train_step_num} end.")
self.last_sync_step = self.train_step_num
ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING))
def _log_experiences(self, samples: List[Dict]) -> None:
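        # Buffer representative samples and flush them as one table on the
        # same cadence as weight synchronization.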
self._sample_exps_to_log.extend(samples)
if self.train_step_num % self.config.synchronizer.sync_interval == 0:
self.monitor.log_table(
"rollout_examples", pd.DataFrame(self._sample_exps_to_log), self.train_step_num
)
self._sample_exps_to_log.clear()
async def shutdown(self) -> None:
self.monitor.close()
@property
def train_step_num(self) -> int:
"""Get the current training step number."""
return self.engine.train_step_num
def is_alive(self) -> bool:
"""Check if the trainer is alive."""
return True
class TrainEngineWrapper(ABC):
"""A wrapper class to wrap various training engines."""
@abstractmethod
def prepare(self) -> None:
"""Do some preparation before training started."""
@property
@abstractmethod
def train_step_num(self) -> int:
"""Get the current training step number."""
@abstractmethod
def train_step(self, batch: Experiences) -> Tuple[bool, Dict]:
"""Training one step.
Args:
batch (Experiences): A batch of experiences to train.
Returns:
bool: Whether to continue training.
Dict: Metrics of the training step.
"""
@abstractmethod
def save_checkpoint(self, block_until_saved: bool = False) -> None:
"""Save the checkpoint."""
@abstractmethod
def sync_weight(self) -> None:
"""Sync the model weight."""
@abstractmethod
def upload_state_dict(self) -> None:
"""Upload the state dict to Synchronizer."""
@abstractmethod
def save_state_dict(self) -> None:
"""Only save the model state dict for Synchronizer."""
def get_trainer_wrapper(config: Config) -> TrainEngineWrapper:
"""Get a trainer wrapper."""
if config.trainer.trainer_type == "verl":
from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper
return VerlPPOTrainerWrapper(config)
    else:
        raise NotImplementedError(f"Unsupported trainer type: {config.trainer.trainer_type}")
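
# Supporting another backend means subclassing TrainEngineWrapper and adding a
# branch above. A minimal sketch (``MyEngineWrapper`` and its internals are
# hypothetical, not an existing backend):
#
#   class MyEngineWrapper(TrainEngineWrapper):
#       def __init__(self, config: Config) -> None:
#           self._step = 0
#
#       def prepare(self) -> None:
#           ...  # load model/optimizer state here
#
#       @property
#       def train_step_num(self) -> int:
#           return self._step
#
#       def train_step(self, batch: Experiences) -> Tuple[bool, Dict]:
#           self._step += 1
#           return True, {"loss": 0.0}  # placeholder metrics
#
#       def save_checkpoint(self, block_until_saved: bool = False) -> None: ...
#       def sync_weight(self) -> None: ...
#       def upload_state_dict(self) -> None: ...
#       def save_state_dict(self) -> None: ...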