trinity.trainer
Subpackages
- trinity.trainer.verl
- Submodules
- trinity.trainer.verl.dp_actor module
- trinity.trainer.verl.fsdp_checkpoint_manager module
- trinity.trainer.verl.fsdp_workers module
ActorRolloutRefWorker
ActorRolloutRefWorker.__init__()
ActorRolloutRefWorker.init_model()
ActorRolloutRefWorker.setup_weight_sync_group()
ActorRolloutRefWorker.sync_weight()
ActorRolloutRefWorker.upload_state_dict()
ActorRolloutRefWorker.set_algorithm()
ActorRolloutRefWorker.update_actor()
ActorRolloutRefWorker.compute_log_prob()
ActorRolloutRefWorker.compute_ref_log_prob()
ActorRolloutRefWorker.save_checkpoint()
ActorRolloutRefWorker.load_checkpoint()
ActorRolloutRefWorker.clear_optimizer_state()
ActorRolloutRefWorker.wait_on_save_thread()
CriticWorker
- trinity.trainer.verl.utils module
- Module contents
Submodules
trinity.trainer.trainer module
Trainer Class
- class trinity.trainer.trainer.Trainer(config: Config)[source]
Bases: object
Consume the experience and train the model.
- async train_step() → bool [source]
Train one step.
- Returns:
Whether to continue training.
- Return type:
bool
- property train_step_num: int
Get the current training step number.
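Because train_step() is a coroutine that reports whether training should continue, a caller would typically await it in a loop. The following is a minimal sketch of that pattern, not the library's own driver code. StubTrainer and run_until_done are hypothetical names introduced only for illustration; the stub mimics the documented train_step() and train_step_num members so the example runs without a Config object or the trinity package.

```python
import asyncio


class StubTrainer:
    """Hypothetical stand-in that mimics the documented Trainer surface."""

    def __init__(self, max_steps: int) -> None:
        self._step = 0
        self._max_steps = max_steps

    @property
    def train_step_num(self) -> int:
        # Number of training steps completed so far.
        return self._step

    async def train_step(self) -> bool:
        # Pretend to train, then report whether more steps remain.
        await asyncio.sleep(0)
        self._step += 1
        return self._step < self._max_steps


async def run_until_done(trainer) -> int:
    """Await train_step() until it returns False; return the final step number."""
    while await trainer.train_step():
        pass
    return trainer.train_step_num


final_step = asyncio.run(run_until_done(StubTrainer(max_steps=3)))
```

Any object exposing the same two members could be passed to run_until_done in place of the stub.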
- class trinity.trainer.trainer.TrainEngineWrapper[source]
Bases: ABC
A wrapper class to wrap various training engines.
- abstract property train_step_num: int
Get the current training step number.
- abstract train_step(batch: Experiences) → Tuple[bool, Dict] [source]
Train one step.
- Parameters:
batch (Experiences) – A batch of experiences to train on.
- Returns:
A tuple of (bool, Dict): whether to continue training, and the metrics of the training step.
- Return type:
Tuple[bool, Dict]
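The abstract interface above can be illustrated with a small subclass. This is a sketch under stated assumptions rather than a real engine: EngineWrapperSketch is a local stand-in that mirrors the two documented abstract members so the example runs without trinity installed, CountingEngine is hypothetical, and a plain list of floats stands in for an Experiences batch. The "mean_reward" metric key is likewise invented for the example.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple


class EngineWrapperSketch(ABC):
    """Local stand-in mirroring the documented abstract interface."""

    @property
    @abstractmethod
    def train_step_num(self) -> int:
        ...

    @abstractmethod
    def train_step(self, batch) -> Tuple[bool, Dict]:
        ...


class CountingEngine(EngineWrapperSketch):
    """Hypothetical engine that 'trains' by averaging a list of rewards."""

    def __init__(self, total_steps: int) -> None:
        self._step = 0
        self._total_steps = total_steps

    @property
    def train_step_num(self) -> int:
        return self._step

    def train_step(self, batch: List[float]) -> Tuple[bool, Dict]:
        # Advance the step counter and compute an illustrative metric.
        self._step += 1
        metrics = {"mean_reward": sum(batch) / len(batch)}
        # First element: whether to continue. Second: metrics for this step.
        return self._step < self._total_steps, metrics


engine = CountingEngine(total_steps=2)
keep_going, metrics = engine.train_step([1.0, 0.0])
```

The caller unpacks the returned tuple, uses the first element to decide whether to request another batch, and logs the second.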
- trinity.trainer.trainer.get_trainer_wrapper(config: Config) → TrainEngineWrapper [source]
Get a trainer wrapper.
trinity.trainer.verl_trainer module
veRL Trainer Class
Modified from verl/trainer/ppo/ray_trainer.py
- class trinity.trainer.verl_trainer.VerlPPOTrainerWrapper(global_config: Config)[source]
Bases: RayPPOTrainer, TrainEngineWrapper
A wrapper for verl.trainer.ppo.RayPPOTrainer.
- __init__(global_config: Config)[source]
Initialize distributed PPO trainer with Ray backend. Note that this trainer runs on the driver process on a single CPU/GPU node.
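- Parameters (as listed in the inherited RayPPOTrainer docstring; the signature above accepts only global_config):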
config – Configuration object containing training parameters.
tokenizer – Tokenizer used for encoding and decoding text.
role_worker_mapping (dict[Role, WorkerType]) – Mapping from roles to worker classes.
resource_pool_manager (ResourcePoolManager) – Manager for Ray resource pools.
ray_worker_group_cls (RayWorkerGroup, optional) – Class for Ray worker groups. Defaults to RayWorkerGroup.
processor – Optional data processor, used for multimodal data.
reward_fn – Function for computing rewards during training.
val_reward_fn – Function for computing rewards during validation.
train_dataset (Optional[Dataset], optional) – Training dataset. Defaults to None.
val_dataset (Optional[Dataset], optional) – Validation dataset. Defaults to None.
collate_fn – Function to collate data samples into batches.
train_sampler (Optional[Sampler], optional) – Sampler for the training dataset. Defaults to None.
device_name (str, optional) – Device name for training (e.g., “cuda”, “cpu”). Defaults to “cuda”.
- init_workers()[source]
Initialize distributed training workers using Ray backend.
Creates:
Ray resource pools from configuration
Worker groups for each role (actor, critic, etc.)
- property train_step_num: int
Get the current training step number.
- train_step(batch: Experiences) → Tuple[bool, Dict] [source]
Train one step.
- Parameters:
batch (Experiences) – A batch of experiences to train on.
- Returns:
A tuple of (bool, Dict): whether to continue training, and the metrics of the training step.
- Return type:
Tuple[bool, Dict]
Module contents
- class trinity.trainer.Trainer(config: Config)[source]
Bases: object
Consume the experience and train the model.
- async train_step() → bool [source]
Train one step.
- Returns:
Whether to continue training.
- Return type:
bool
- property train_step_num: int
Get the current training step number.
- class trinity.trainer.TrainEngineWrapper[source]
Bases: ABC
A wrapper class to wrap various training engines.
- abstract property train_step_num: int
Get the current training step number.
- abstract train_step(batch: Experiences) → Tuple[bool, Dict] [source]
Train one step.
- Parameters:
batch (Experiences) – A batch of experiences to train on.
- Returns:
A tuple of (bool, Dict): whether to continue training, and the metrics of the training step.
- Return type:
Tuple[bool, Dict]
- trinity.trainer.get_trainer_wrapper(config: Config) → TrainEngineWrapper [source]
Get a trainer wrapper.