trinity.trainer

Subpackages

Submodules

trinity.trainer.trainer module

Trainer Class

class trinity.trainer.trainer.Trainer(config: Config)

Bases: object

Consume experiences and train the model.

__init__(config: Config) → None

prepare() → None

Prepare the trainer.

train() → str

Train the model.

train_step() → bool

Train one step.

Returns:

Whether to continue training.

Return type:

bool

need_sync() → bool

Return whether the model weights need to be synchronized.

sync_weight() → None

Synchronize the model weights.

flush_log(step: int) → None

Flush the log of the current step.

shutdown() → None
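The methods above suggest a simple driver loop. The following is a minimal sketch under stated assumptions, not Trinity's actual implementation: `StubTrainer`, `run`, `total_steps`, and `sync_interval` are hypothetical stand-ins, and the order in which `need_sync()`, `sync_weight()`, and `flush_log()` are called is assumed for illustration. Only the method names and the meaning of the `train_step()` return value come from the reference above.

```python
class StubTrainer:
    """Hypothetical stand-in that exposes the documented method names."""

    def __init__(self, total_steps: int, sync_interval: int) -> None:
        self.total_steps = total_steps
        self.sync_interval = sync_interval  # assumption: sync every N steps
        self.step = 0
        self.events: list[str] = []

    def prepare(self) -> None:
        self.events.append("prepare")

    def train_step(self) -> bool:
        # Returns whether training should continue, as documented.
        self.step += 1
        self.events.append(f"train_step:{self.step}")
        return self.step < self.total_steps

    def need_sync(self) -> bool:
        return self.step % self.sync_interval == 0

    def sync_weight(self) -> None:
        self.events.append(f"sync_weight:{self.step}")

    def flush_log(self, step: int) -> None:
        self.events.append(f"flush_log:{step}")

    def shutdown(self) -> None:
        self.events.append("shutdown")


def run(trainer: StubTrainer) -> list[str]:
    """Drive the trainer until train_step() reports that it should stop."""
    trainer.prepare()
    try:
        keep_going = True
        while keep_going:
            keep_going = trainer.train_step()
            if trainer.need_sync():
                trainer.sync_weight()
            trainer.flush_log(trainer.step)
    finally:
        # Always release resources, even if a step raises.
        trainer.shutdown()
    return trainer.events
```

Calling `run(StubTrainer(total_steps=4, sync_interval=2))` records four training steps, with a weight sync after steps 2 and 4, and ends with `shutdown`.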

class trinity.trainer.trainer.TrainEngineWrapper

Bases: ABC

A wrapper class to wrap various training engines.

abstract prepare() → None

Perform any preparation needed before training starts.

abstract property train_step_num: int

Get the current training step number.

abstract train_step() → bool

Run one training step.

abstract save_checkpoint() → None

Save the checkpoint.

abstract sync_weight() → None

Synchronize the model weights.

abstract shutdown() → None

Shut down the engine.
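To illustrate what implementing this interface involves, here is a minimal sketch. It does not import Trinity: `EngineWrapperSketch` is a local mirror of the abstract members listed above, and `CountingEngine` with its `max_steps` parameter is a hypothetical toy engine. The assumption that `train_step()` returns whether to continue is borrowed from the `Trainer.train_step()` description.

```python
from abc import ABC, abstractmethod


class EngineWrapperSketch(ABC):
    """Local mirror of the documented abstract interface (not the real class)."""

    @abstractmethod
    def prepare(self) -> None: ...

    @property
    @abstractmethod
    def train_step_num(self) -> int: ...

    @abstractmethod
    def train_step(self) -> bool: ...

    @abstractmethod
    def save_checkpoint(self) -> None: ...

    @abstractmethod
    def sync_weight(self) -> None: ...

    @abstractmethod
    def shutdown(self) -> None: ...


class CountingEngine(EngineWrapperSketch):
    """Hypothetical engine that only counts steps."""

    def __init__(self, max_steps: int) -> None:
        self._step = 0
        self._max_steps = max_steps
        self.prepared = False
        self.checkpoints: list[int] = []

    def prepare(self) -> None:
        self.prepared = True

    @property
    def train_step_num(self) -> int:
        return self._step

    def train_step(self) -> bool:
        # Assumption: True means "keep training".
        self._step += 1
        return self._step < self._max_steps

    def save_checkpoint(self) -> None:
        # Record the step at which a checkpoint would be written.
        self.checkpoints.append(self._step)

    def sync_weight(self) -> None:
        pass  # nothing to synchronize in this toy engine

    def shutdown(self) -> None:
        self.prepared = False
```

Because every member is abstract, instantiating `EngineWrapperSketch` directly raises `TypeError`; a subclass becomes instantiable only once it overrides all six members.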

trinity.trainer.trainer.get_trainer_wrapper(config: Config) → TrainEngineWrapper

Get a trainer wrapper.
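A factory like this is commonly written as a lookup from a configuration value to a wrapper class. The sketch below shows that pattern only; it is not the library's implementation. `SketchConfig`, its `trainer_type` field, `VerlLikeWrapper`, and `get_wrapper_sketch` are all hypothetical names, and the real `Config` layout is not described in this reference.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class SketchConfig:
    # Hypothetical field used to select an engine.
    trainer_type: str


class VerlLikeWrapper:
    """Hypothetical wrapper class standing in for a concrete engine."""

    def __init__(self, config: SketchConfig) -> None:
        self.config = config


# Registry mapping a type name to a constructor taking the config.
_REGISTRY: Dict[str, Callable[[SketchConfig], object]] = {
    "verl": VerlLikeWrapper,
}


def get_wrapper_sketch(config: SketchConfig) -> object:
    """Return the wrapper registered for config.trainer_type."""
    try:
        factory = _REGISTRY[config.trainer_type]
    except KeyError:
        raise NotImplementedError(
            f"unsupported trainer type: {config.trainer_type!r}"
        ) from None
    return factory(config)
```

Keeping the mapping in a dictionary means a new engine can be added by registering one entry rather than editing a chain of conditionals.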

trinity.trainer.verl_trainer module

veRL Trainer Class

Modified from verl/trainer/ppo/ray_trainer.py

class trinity.trainer.verl_trainer.VerlPPOTrainerWrapper(global_config: Config)

Bases: RayPPOTrainer, TrainEngineWrapper

A wrapper for verl.trainer.ppo.RayPPOTrainer.

__init__(global_config: Config)

Initialize the distributed PPO trainer with the Ray backend.

init_workers()

Initialize the distributed training workers using the Ray backend.

Creates:

  1. Ray resource pools from configuration

  2. Worker groups for each role (actor, critic, etc.)
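The two steps listed above can be pictured without Ray as plain data: a pool specification read from configuration, and a grouping of roles by the pool they run in. This is a rough sketch with made-up values; `pool_spec`, `role_to_pool`, and `group_roles_by_pool` are hypothetical names, and no actual Ray or veRL calls are shown.

```python
from collections import defaultdict
from typing import Dict, List

# Hypothetical configuration: pool name -> number of GPUs on each node.
pool_spec: Dict[str, List[int]] = {"global_pool": [8, 8]}

# Hypothetical assignment of each role to a pool.
role_to_pool: Dict[str, str] = {
    "actor": "global_pool",
    "critic": "global_pool",
}


def group_roles_by_pool(
    pool_spec: Dict[str, List[int]], role_to_pool: Dict[str, str]
) -> Dict[str, List[str]]:
    """Collect the roles that would share each resource pool."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for role, pool in role_to_pool.items():
        if pool not in pool_spec:
            raise KeyError(f"role {role!r} refers to unknown pool {pool!r}")
        groups[pool].append(role)
    return dict(groups)


def total_gpus(pool_spec: Dict[str, List[int]]) -> int:
    """Sum the GPUs requested across all pools and nodes."""
    return sum(sum(per_node) for per_node in pool_spec.values())
```

In this picture, one worker group would be created per role within each pool returned by `group_roles_by_pool`.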

reset_experiences_example_table()

property train_step_num: int

Get the current training step number.

prepare()

Perform any preparation needed before training starts.

train_step() → bool

Run one training step.

save_checkpoint() → None

Save the checkpoint.

sync_weight() → None

Synchronize the model weights.

sft_to_rft() → None

shutdown() → None

Shut down the engine.

Module contents

class trinity.trainer.Trainer(config: Config)

Bases: object

Consume experiences and train the model.

__init__(config: Config) → None

prepare() → None

Prepare the trainer.

train() → str

Train the model.

train_step() → bool

Train one step.

Returns:

Whether to continue training.

Return type:

bool

need_sync() → bool

Return whether the model weights need to be synchronized.

sync_weight() → None

Synchronize the model weights.

flush_log(step: int) → None

Flush the log of the current step.

shutdown() → None

class trinity.trainer.TrainEngineWrapper

Bases: ABC

A wrapper class to wrap various training engines.

abstract prepare() → None

Perform any preparation needed before training starts.

abstract property train_step_num: int

Get the current training step number.

abstract train_step() → bool

Run one training step.

abstract save_checkpoint() → None

Save the checkpoint.

abstract sync_weight() → None

Synchronize the model weights.

abstract shutdown() → None

Shut down the engine.

trinity.trainer.get_trainer_wrapper(config: Config) → TrainEngineWrapper

Get a trainer wrapper.