trinity.trainer

Submodules

trinity.trainer.trainer module

Trainer Class

class trinity.trainer.trainer.Trainer(config: Config)[source]

Bases: object

Consume experiences and train the model.

__init__(config: Config) None[source]
prepare() None[source]

Prepare the trainer.

async train() str[source]

Train the model.

async train_step() bool[source]

Train one step.

Returns:

Whether to continue training.

Return type:

bool

need_sync() bool[source]

Whether the model weights need to be synced.

sync_weight() None[source]

Sync the model weights.

async shutdown() None[source]
property train_step_num: int

Get the current training step number.

is_alive() bool[source]

Check if the trainer is alive.
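
A minimal sketch of how the class above might be driven, assuming a Config object is built elsewhere (the import path for Config is an assumption; adjust it to your project layout):

import asyncio

from trinity.common.config import Config  # assumed location of Config
from trinity.trainer.trainer import Trainer


async def run(config: Config) -> str:
    trainer = Trainer(config)        # consumes experiences and trains the model
    trainer.prepare()                # prepare the trainer before training
    status = await trainer.train()   # presumably loops over train_step() until it signals to stop
    await trainer.shutdown()
    return status

# config = Config(...)               # build or load the config however your project does
# print(asyncio.run(run(config)))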

class trinity.trainer.trainer.TrainEngineWrapper[source]

Bases: ABC

An abstract base class that wraps various training engines behind a common interface.

abstract prepare() None[source]

Do some preparation before training starts.

abstract property train_step_num: int

Get the current training step number.

abstract train_step(batch: Experiences) Tuple[bool, Dict][source]

Train one step.

Parameters:

batch (Experiences) – A batch of experiences to train on.

Returns:

bool: Whether to continue training.

Dict: Metrics of the training step.

Return type:

Tuple[bool, Dict]

abstract save_checkpoint(block_until_saved: bool = False) None[source]

Save the checkpoint.

abstract sync_weight() None[source]

Sync the model weights.

abstract upload_state_dict() None[source]

Upload the state dict to the Synchronizer.

abstract save_state_dict() None[source]

Save only the model state dict for the Synchronizer.
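
A skeleton subclass illustrating the abstract interface above, as a sketch only (not the real verl-backed implementation; the Experiences import path is an assumption):

from typing import Dict, Tuple

from trinity.common.experience import Experiences  # assumed import path
from trinity.trainer.trainer import TrainEngineWrapper


class DummyEngineWrapper(TrainEngineWrapper):
    """Toy engine that just counts steps; replace the bodies with a real backend."""

    def __init__(self, max_steps: int = 10) -> None:
        self._step = 0
        self._max_steps = max_steps

    def prepare(self) -> None:
        # Build the model, optimizer, and dataloaders here.
        pass

    @property
    def train_step_num(self) -> int:
        return self._step

    def train_step(self, batch: Experiences) -> Tuple[bool, Dict]:
        self._step += 1
        metrics = {"train/step": self._step}
        return self._step < self._max_steps, metrics  # (continue training?, metrics)

    def save_checkpoint(self, block_until_saved: bool = False) -> None:
        pass

    def sync_weight(self) -> None:
        pass

    def upload_state_dict(self) -> None:
        pass

    def save_state_dict(self) -> None:
        pass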

trinity.trainer.trainer.get_trainer_wrapper(config: Config) TrainEngineWrapper[source]

Get a trainer wrapper.
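
A sketch of consuming the returned wrapper directly; in the framework the Trainer class above drives this loop, and config and batches are assumed to be provided elsewhere:

from trinity.trainer.trainer import get_trainer_wrapper

engine = get_trainer_wrapper(config)       # backend selected from the config (e.g. verl)
engine.prepare()

for batch in batches:                      # batches: an iterable of Experiences
    keep_going, metrics = engine.train_step(batch)  # unpack the (bool, Dict) tuple
    print(engine.train_step_num, metrics)
    if not keep_going:
        break

engine.save_checkpoint(block_until_saved=True)      # block until the checkpoint is written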

trinity.trainer.verl_trainer module

veRL Trainer Class

Modified from verl/trainer/ppo/ray_trainer.py

class trinity.trainer.verl_trainer.VerlPPOTrainerWrapper(global_config: Config)[source]

Bases: RayPPOTrainer, TrainEngineWrapper

A wrapper for verl.trainer.ppo.RayPPOTrainer.

__init__(global_config: Config)[source]

Initialize distributed PPO trainer with Ray backend. Note that this trainer runs on the driver process on a single CPU/GPU node.

Parameters:
  • global_config (Config) – Global configuration object containing training parameters.

  • tokenizer – Tokenizer used for encoding and decoding text.

  • role_worker_mapping (dict[Role, WorkerType]) – Mapping from roles to worker classes.

  • resource_pool_manager (ResourcePoolManager) – Manager for Ray resource pools.

  • ray_worker_group_cls (RayWorkerGroup, optional) – Class for Ray worker groups. Defaults to RayWorkerGroup.

  • processor – Optional data processor, used for multimodal data

  • reward_fn – Function for computing rewards during training.

  • val_reward_fn – Function for computing rewards during validation.

  • train_dataset (Optional[Dataset], optional) – Training dataset. Defaults to None.

  • val_dataset (Optional[Dataset], optional) – Validation dataset. Defaults to None.

  • collate_fn – Function to collate data samples into batches.

  • train_sampler (Optional[Sampler], optional) – Sampler for the training dataset. Defaults to None.

  • device_name (str, optional) – Device name for training (e.g., “cuda”, “cpu”). Defaults to “cuda”.

init_workers()[source]

Initialize distributed training workers using the Ray backend.

Creates:

  1. Ray resource pools from configuration

  2. Worker groups for each role (actor, critic, etc.)

property train_step_num: int

Get the current training step number.

prepare()[source]

Do some preparation before training starts.

save_state_dict()[source]

Save only the model state dict for the Synchronizer.

upload_state_dict()[source]

Upload the state dict to the Synchronizer.

train_step(batch: Experiences) Tuple[bool, Dict][source]

Train one step.

Parameters:

batch (Experiences) – A batch of experiences to train on.

Returns:

bool: Whether to continue training.

Dict: Metrics of the training step.

Return type:

Tuple[bool, Dict]

save_checkpoint(block_until_saved: bool = False) None[source]

Save the checkpoint.

sync_weight() None[source]

Sync the model weights.

sft_to_rft() None[source]
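
Roughly how this wrapper's lifecycle might be driven by hand, as a sketch only (in the framework these calls are presumably made through get_trainer_wrapper and Trainer; global_config and batch are assumed to exist):

from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper

engine = VerlPPOTrainerWrapper(global_config)    # global_config: Config, assumed available
engine.init_workers()                            # Ray resource pools + per-role worker groups
engine.prepare()                                 # preparation before training starts

keep_going, metrics = engine.train_step(batch)   # batch: Experiences, assumed available
engine.sync_weight()                             # sync the model weights
engine.save_checkpoint(block_until_saved=False)  # non-blocking checkpoint save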

Module contents

The package re-exports Trainer, TrainEngineWrapper, and get_trainer_wrapper from the trinity.trainer.trainer module; see that module's section above for their full documentation.