trinity.trainer

Submodules

trinity.trainer.trainer module

Trainer class. This file is modified from verl.trainer.main_ppo.py and is a reproduction of Jiayi-Pan/TinyZero.

Note that the main entry point is not combined with ray_trainer, because ray_trainer is also used by other main scripts.

class trinity.trainer.trainer.TrainEngineWrapper

Bases: ABC

A wrapper class to wrap various training engines.

abstract prepare() → None

Prepare the engine before training starts.

abstract train_rft_step(experiences) → Tuple[bool, int]

Train on the RFT (reinforcement fine-tuning) data.

abstract train_sft_step(experiences) → Tuple[bool, int]

Train on the SFT (supervised fine-tuning) data.

abstract train_dpo_step(experiences) → Tuple[bool, int]

Train on the DPO (direct preference optimization) data.

abstract save_checkpoint() → None

Save the checkpoint.

abstract sync_weight() → None

Sync the model weights.

abstract set_mode(algo_type: AlgorithmType) → None

Set the training mode.

abstract shutdown() → None

Shut down the engine.
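
As an illustration of the interface above, the following minimal sketch defines a stand-in abstract base with the same method names, plus a toy subclass. Everything in it is hypothetical: EngineInterface and CountingEngine are invented for the example, and reading the returned Tuple[bool, int] as (continue flag, step count) is an assumption, since the reference does not state what the two values mean.

```python
from abc import ABC, abstractmethod
from typing import List, Tuple


class EngineInterface(ABC):
    """Hypothetical stand-in that mirrors the method names of TrainEngineWrapper."""

    @abstractmethod
    def prepare(self) -> None: ...

    @abstractmethod
    def train_rft_step(self, experiences) -> Tuple[bool, int]: ...

    @abstractmethod
    def train_sft_step(self, experiences) -> Tuple[bool, int]: ...

    @abstractmethod
    def train_dpo_step(self, experiences) -> Tuple[bool, int]: ...

    @abstractmethod
    def save_checkpoint(self) -> None: ...

    @abstractmethod
    def sync_weight(self) -> None: ...

    @abstractmethod
    def set_mode(self, algo_type) -> None: ...

    @abstractmethod
    def shutdown(self) -> None: ...


class CountingEngine(EngineInterface):
    """Toy engine that counts steps instead of updating a model."""

    def __init__(self, max_steps: int) -> None:
        self.max_steps = max_steps
        self.step = 0
        self.mode = None
        self.log: List[str] = []

    def prepare(self) -> None:
        self.log.append("prepare")

    def _step(self, experiences) -> Tuple[bool, int]:
        # Assumed meaning of the tuple: (keep training?, steps completed so far).
        self.step += 1
        return self.step < self.max_steps, self.step

    def train_rft_step(self, experiences) -> Tuple[bool, int]:
        return self._step(experiences)

    def train_sft_step(self, experiences) -> Tuple[bool, int]:
        return self._step(experiences)

    def train_dpo_step(self, experiences) -> Tuple[bool, int]:
        return self._step(experiences)

    def save_checkpoint(self) -> None:
        self.log.append(f"checkpoint@{self.step}")

    def sync_weight(self) -> None:
        self.log.append("sync")

    def set_mode(self, algo_type) -> None:
        self.mode = algo_type

    def shutdown(self) -> None:
        self.log.append("shutdown")
```

Because every method is declared abstract, instantiating the base class directly raises TypeError; a subclass must implement all eight methods before it can be constructed.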

trinity.trainer.trainer.get_trainer_wrapper(config: Config) → TrainEngineWrapper

Get the trainer wrapper for the given configuration.
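
A factory of this kind typically selects a backend from the configuration. The sketch below shows one way that could look. It is not the library's implementation: ToyConfig, its trainer_type field, ToyVerlWrapper, and get_toy_trainer_wrapper are all hypothetical names chosen for the example.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class ToyConfig:
    """Hypothetical stand-in for Config; trainer_type is an assumed field."""

    trainer_type: str = "verl"


class ToyVerlWrapper:
    """Hypothetical stand-in for a concrete wrapper such as VerlPPOTrainerWrapper."""

    def __init__(self, config: ToyConfig) -> None:
        self.config = config


# Registry mapping a backend name to the constructor of its wrapper.
_WRAPPERS: Dict[str, Callable[[ToyConfig], object]] = {"verl": ToyVerlWrapper}


def get_toy_trainer_wrapper(config: ToyConfig):
    """Return the wrapper registered for config.trainer_type."""
    try:
        factory = _WRAPPERS[config.trainer_type]
    except KeyError:
        raise NotImplementedError(
            f"Unknown trainer type: {config.trainer_type!r}"
        ) from None
    return factory(config)
```

Keeping the lookup in a registry means a new backend only needs one new entry, and callers depend on the abstract interface rather than on a concrete class.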

trinity.trainer.verl_trainer module

veRL Trainer Class

Modified from verl/trainer/ppo/ray_trainer.py

class trinity.trainer.verl_trainer.VerlPPOTrainerWrapper(global_config: Config)

Bases: RayPPOTrainer, TrainEngineWrapper

A wrapper for verl.trainer.ppo.RayPPOTrainer.

__init__(global_config: Config)

reset_experiences_example_table()

prepare()

Prepare the engine before training starts.

train_dpo_step(experiences: Experiences) → Tuple[bool, int]

Train on the DPO data.

train_sft_step(experiences: Experiences) → Tuple[bool, int]

Train on the SFT data.

train_rft_step(experiences: Experiences) → Tuple[bool, int]

Train on the RFT data.

save_checkpoint() → None

Save the checkpoint.

sync_weight() → None

Sync the model weights.

set_mode(algorithm_type: AlgorithmType = AlgorithmType.PPO) → None

Set the training mode.

sft_to_rft() → None

shutdown() → None

Shut down the engine.
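
The methods listed above suggest a lifecycle of prepare, set_mode, repeated training steps with occasional weight syncs, a checkpoint, and shutdown. The sketch below walks a fake engine through that order. It is an assumption about typical usage, not the documented call sequence: RecordingEngine, ToyAlgorithm, run, and the sync_every parameter are hypothetical, and the (continue flag, step count) reading of the return tuple is likewise assumed.

```python
from enum import Enum
from typing import Iterable, List, Tuple


class ToyAlgorithm(Enum):
    """Hypothetical stand-in for AlgorithmType; only PPO is shown."""

    PPO = "ppo"


class RecordingEngine:
    """Fake engine that records calls; it reuses the wrapper's method names."""

    def __init__(self) -> None:
        self.calls: List[str] = []
        self.step = 0

    def prepare(self) -> None:
        self.calls.append("prepare")

    def set_mode(self, algorithm_type: ToyAlgorithm = ToyAlgorithm.PPO) -> None:
        self.calls.append(f"set_mode:{algorithm_type.name}")

    def train_rft_step(self, experiences) -> Tuple[bool, int]:
        self.step += 1
        self.calls.append("train_rft_step")
        return True, self.step

    def sync_weight(self) -> None:
        self.calls.append("sync_weight")

    def save_checkpoint(self) -> None:
        self.calls.append("save_checkpoint")

    def shutdown(self) -> None:
        self.calls.append("shutdown")


def run(engine: RecordingEngine, batches: Iterable, sync_every: int = 2) -> int:
    """Drive the engine over the batches and return the last step count."""
    engine.prepare()
    engine.set_mode(ToyAlgorithm.PPO)
    step = 0
    try:
        for batch in batches:
            keep_going, step = engine.train_rft_step(batch)
            if step % sync_every == 0:
                engine.sync_weight()  # push updated weights every few steps
            if not keep_going:
                break
        engine.save_checkpoint()
    finally:
        engine.shutdown()  # release resources even if a step raises
    return step
```

Placing shutdown in a finally clause ensures the engine is released even when a training step fails partway through.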

Module contents

class trinity.trainer.TrainEngineWrapper

Bases: ABC

A wrapper class to wrap various training engines.

abstract prepare() → None

Prepare the engine before training starts.

abstract train_rft_step(experiences) → Tuple[bool, int]

Train on the RFT data.

abstract train_sft_step(experiences) → Tuple[bool, int]

Train on the SFT data.

abstract train_dpo_step(experiences) → Tuple[bool, int]

Train on the DPO data.

abstract save_checkpoint() → None

Save the checkpoint.

abstract sync_weight() → None

Sync the model weights.

abstract set_mode(algo_type: AlgorithmType) → None

Set the training mode.

abstract shutdown() → None

Shut down the engine.

trinity.trainer.get_trainer_wrapper(config: Config) → TrainEngineWrapper

Get the trainer wrapper for the given configuration.