# -*- coding: utf-8 -*-
"""
Trainer Class
"""
from __future__ import annotations
import os
from abc import ABC, abstractmethod
import ray
from trinity.common.config import Config
from trinity.common.constants import RunningStatus, SyncMethod
from trinity.utils.log import get_logger
class Trainer:
    """Consume the experience and train the model.

    The actual training work is delegated to an engine wrapper obtained from
    ``get_trainer_wrapper``; this class adds the outer training loop, the
    periodic weight synchronization and the checkpoint-on-shutdown logic.
    """

    def __init__(self, config: Config) -> None:
        """Create the trainer and build its training engine.

        Args:
            config: Global configuration. This class reads ``trainer``,
                ``synchronizer``, ``explorer`` and ``checkpoint_job_dir``
                from it.
        """
        self.config = config
        self.logger = get_logger(__name__)
        self.engine = get_trainer_wrapper(config)
        # Handle of the explorer actor; resolved lazily on the first NCCL sync.
        self.explorer_ref = None

    def prepare(self) -> None:
        """Prepare the trainer by preparing the underlying engine."""
        self.engine.prepare()

    def train(self) -> str:
        """Train the model until the engine stops or a step fails.

        Each iteration runs one training step and, when ``need_sync`` says
        so, synchronizes the model weights. Any exception raised by a step or
        a sync is logged together with its traceback and ends the loop; it is
        not re-raised.

        Returns:
            str: The trainer name taken from ``config.trainer.name``.
        """
        # Imported locally because it is only needed on the failure path.
        import traceback

        while True:
            try:
                if not self.train_step():
                    break
                if self.need_sync():
                    self.sync_weight()
            except Exception as e:
                # Keep the original one-line summary and append the full
                # traceback so that the failure can actually be diagnosed.
                self.logger.error(f"Error in Trainer: {e}\n{traceback.format_exc()}")
                break
        self.logger.info("--------------------\n> Trainer finished.\n--------------------")
        return self.config.trainer.name

    def train_step(self) -> bool:
        """Train one step.

        Returns:
            bool: Whether to continue training.
        """
        return self.engine.train_step()

    def need_sync(self) -> bool:
        """Whether to sync the model weight.

        Returns:
            bool: True when the engine's current step number is a multiple of
            ``config.synchronizer.sync_interval``.
        """
        return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0

    def sync_weight(self) -> None:
        """Sync the model weight.

        With the NCCL sync method the explorer actor is looked up by name
        through Ray (the handle is cached after the first lookup) and its
        running status is queried first; if the explorer has already stopped,
        a warning is logged and the sync is skipped.
        """
        if self.config.synchronizer.sync_method == SyncMethod.NCCL:
            if self.explorer_ref is None:
                self.explorer_ref = ray.get_actor(self.config.explorer.name)
            explorer_status = ray.get(self.explorer_ref.running_status.remote())
            if explorer_status == RunningStatus.STOPPED:
                self.logger.warning("Explorer has already stopped. Skipping sync weight.")
                return
        self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num}.")
        self.engine.sync_weight()

    def flush_log(self, step: int) -> None:
        """Flush the log of the current step.

        Args:
            step: The step number passed on to the engine's logger.
        """
        # An empty record with commit=True is sent to the engine's logger;
        # presumably this makes it emit whatever it buffered for the step.
        self.engine.logger.log({}, step=step, commit=True)

    def shutdown(self) -> None:
        """Save the last checkpoint if it is missing and close the logger.

        The checkpoint of the current step is considered missing when the
        directory ``global_step_<step>`` under ``config.checkpoint_job_dir``
        does not exist or is empty.
        """
        step_num = self.engine.train_step_num
        path = os.path.join(self.config.checkpoint_job_dir, f"global_step_{step_num}")
        if not os.path.isdir(path) or not os.listdir(path):
            self.engine.save_checkpoint()
        self.engine.logger.close()
class TrainEngineWrapper(ABC):
    """Abstract interface that every wrapped training engine must implement."""

    @abstractmethod
    def prepare(self) -> None:
        """Perform whatever setup is required before training begins."""

    @property
    @abstractmethod
    def train_step_num(self) -> int:
        """The current training step number."""

    @abstractmethod
    def train_step(self) -> bool:
        """Run one training step.

        Returns:
            bool: Whether training should continue afterwards.
        """

    @abstractmethod
    def save_checkpoint(self) -> None:
        """Persist a checkpoint."""

    @abstractmethod
    def sync_weight(self) -> None:
        """Synchronize the model weight."""

    @abstractmethod
    def shutdown(self) -> None:
        """Shut the engine down."""
def get_trainer_wrapper(config: Config) -> TrainEngineWrapper:
    """Get a trainer wrapper.

    Args:
        config: Global configuration; ``config.trainer.trainer_type`` selects
            the engine. Only ``"verl"`` is handled here.

    Returns:
        TrainEngineWrapper: The engine wrapper built from ``config``.

    Raises:
        NotImplementedError: If ``config.trainer.trainer_type`` is not a
            supported engine type.
    """
    trainer_type = config.trainer.trainer_type
    if trainer_type != "verl":
        # Name the offending value so a misconfiguration is easy to spot.
        raise NotImplementedError(
            f"Unsupported trainer type: {trainer_type!r} (supported: 'verl')"
        )
    # Imported lazily so the verl dependency is only loaded when selected.
    from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper

    return VerlPPOTrainerWrapper(config)