# Source code for trinity.trainer.trainer

# -*- coding: utf-8 -*-
"""
Trainer Class
"""
from __future__ import annotations

import os
from abc import ABC, abstractmethod

import ray

from trinity.common.config import Config
from trinity.common.constants import RunningStatus, SyncMethod
from trinity.utils.log import get_logger


class Trainer:
    """Consume the experience and train the model."""

    def __init__(self, config: Config) -> None:
        self.config = config
        self.logger = get_logger(__name__)
        self.engine = get_trainer_wrapper(config)
        # Lazily-resolved handle to the explorer actor; only needed for NCCL sync.
        self.explorer_ref = None

    def prepare(self) -> None:
        """Prepare the trainer."""
        self.engine.prepare()

    def train(self) -> str:
        """Train the model.

        Runs the train loop until the engine signals completion or an
        exception escapes a step, then returns the trainer's name.

        Returns:
            str: ``config.trainer.name``.
        """
        while True:
            try:
                train_continue = self.train_step()
                if not train_continue:
                    break
                if self.need_sync():
                    self.sync_weight()
            except Exception:
                # logger.exception records the full traceback; the previous
                # logger.error(f"... {e}") lost it, making failures hard to debug.
                self.logger.exception("Error in Trainer")
                break
        self.logger.info("--------------------\n> Trainer finished.\n--------------------")
        return self.config.trainer.name

    def train_step(self) -> bool:
        """Train one step.

        Returns:
            bool: Whether to continue training.
        """
        return self.engine.train_step()

    def need_sync(self) -> bool:
        """Whether to sync the model weight."""
        return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0

    def sync_weight(self) -> None:
        """Sync the model weight."""
        if self.config.synchronizer.sync_method == SyncMethod.NCCL:
            # NCCL sync is collective: skip it if the explorer side is gone,
            # otherwise this trainer would block forever waiting for a peer.
            if self.explorer_ref is None:
                self.explorer_ref = ray.get_actor(self.config.explorer.name)
            explorer_status = ray.get(self.explorer_ref.running_status.remote())
            if explorer_status == RunningStatus.STOPPED:
                self.logger.warning("Explorer has already stopped. Skipping sync weight.")
                return
        self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num}.")
        self.engine.sync_weight()

    def flush_log(self, step: int) -> None:
        """Flush the log of the current step."""
        self.engine.logger.log({}, step=step, commit=True)

    def shutdown(self) -> None:
        """Shut down the trainer, saving a final checkpoint if needed."""
        # If no checkpoint exists for the last trained step, save one now so
        # progress is not lost on exit.
        step_num = self.engine.train_step_num
        path = os.path.join(self.config.checkpoint_job_dir, f"global_step_{step_num}")
        if not os.path.isdir(path) or not os.listdir(path):
            self.engine.save_checkpoint()
        self.engine.logger.close()
class TrainEngineWrapper(ABC):
    """A wrapper class to wrap various training engines."""

    @abstractmethod
    def prepare(self) -> None:
        """Do some preparation before training started."""

    @property
    @abstractmethod
    def train_step_num(self) -> int:
        """Get the current training step number."""

    @abstractmethod
    def train_step(self) -> bool:
        """Training."""

    @abstractmethod
    def save_checkpoint(self) -> None:
        """Save the checkpoint."""

    @abstractmethod
    def sync_weight(self) -> None:
        """Sync the model weight."""

    @abstractmethod
    def shutdown(self) -> None:
        """Shutdown the engine."""
def get_trainer_wrapper(config: Config) -> TrainEngineWrapper:
    """Get a trainer wrapper.

    Args:
        config (Config): The global configuration.

    Returns:
        TrainEngineWrapper: A wrapper around the configured training engine.

    Raises:
        NotImplementedError: If ``config.trainer.trainer_type`` is not supported.
    """
    if config.trainer.trainer_type == "verl":
        # Imported lazily so the verl dependency is only required when
        # the verl trainer is actually selected.
        from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper

        return VerlPPOTrainerWrapper(config)
    # Name the offending value so misconfiguration is easy to diagnose
    # (the original bare NotImplementedError carried no message).
    raise NotImplementedError(f"Unsupported trainer type: {config.trainer.trainer_type}")