trinity.trainer package
Subpackages
Submodules
- trinity.trainer.trainer module
- trinity.trainer.verl_trainer module
  - VerlPPOTrainerWrapper
    - VerlPPOTrainerWrapper.__init__()
    - VerlPPOTrainerWrapper.init_workers()
    - VerlPPOTrainerWrapper.train_step_num
    - VerlPPOTrainerWrapper.prepare()
    - VerlPPOTrainerWrapper.save_state_dict()
    - VerlPPOTrainerWrapper.upload_state_dict()
    - VerlPPOTrainerWrapper.train_step()
    - VerlPPOTrainerWrapper.save_checkpoint()
    - VerlPPOTrainerWrapper.sync_weight()
    - VerlPPOTrainerWrapper.sft_to_rft()
Module contents
- class trinity.trainer.Trainer(config: Config)
Bases: object
Consume experiences and train the model on them.
- async train_step() → bool
Train one step.
- Returns:
Whether to continue training.
- Return type:
bool
- property train_step_num: int
Get the current training step number.
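Because train_step() is a coroutine that reports whether training should continue, a caller would typically await it in a loop until it returns False, then read train_step_num. The following is a minimal sketch under that assumption. StubTrainer and run_training are hypothetical names, and StubTrainer only mimics the documented interface, because constructing a real Trainer requires a Config that is not shown here.

```python
import asyncio


class StubTrainer:
    """Hypothetical stand-in that mirrors Trainer's documented interface."""

    def __init__(self, max_steps: int) -> None:
        self._step = 0
        self._max_steps = max_steps

    @property
    def train_step_num(self) -> int:
        # Current training step number, as in Trainer.train_step_num.
        return self._step

    async def train_step(self) -> bool:
        # Pretend to consume one batch of experiences and update the model.
        self._step += 1
        # Report whether training should continue after this step.
        return self._step < self._max_steps


async def run_training(trainer) -> int:
    # Await train_step() until it signals that training should stop.
    while await trainer.train_step():
        pass
    return trainer.train_step_num


if __name__ == "__main__":
    print(asyncio.run(run_training(StubTrainer(max_steps=5))))  # prints 5
```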
- class trinity.trainer.TrainEngineWrapper
Bases: ABC
Abstract base class that wraps different training engines behind a common interface.
- abstract property train_step_num: int
Get the current training step number.
- abstract train_step(batch: Experiences) → Tuple[bool, Dict]
Train one step.
- Parameters:
batch (Experiences) – A batch of experiences to train.
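- Returns:
A tuple (continue_training, metrics): continue_training (bool) indicates whether to continue training, and metrics (Dict) holds the metrics of the training step.
- Return type:
Tuple[bool, Dict]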
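A concrete engine must implement both abstract members: the train_step_num property and a train_step method that returns a (continue, metrics) pair. The sketch below mirrors that interface in a local ABC so it runs on its own. ToyEngine and its "batch_mean" metric are hypothetical, and a plain list of floats stands in for the real Experiences type.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple


class TrainEngineWrapper(ABC):
    """Local mirror of the documented abstract interface (illustrative only)."""

    @property
    @abstractmethod
    def train_step_num(self) -> int: ...

    @abstractmethod
    def train_step(self, batch: Any) -> Tuple[bool, Dict]: ...


class ToyEngine(TrainEngineWrapper):
    """Hypothetical engine that only records simple statistics per batch."""

    def __init__(self, total_steps: int) -> None:
        self._step = 0
        self._total_steps = total_steps

    @property
    def train_step_num(self) -> int:
        return self._step

    def train_step(self, batch: List[float]) -> Tuple[bool, Dict]:
        self._step += 1
        mean = sum(batch) / len(batch) if batch else 0.0
        metrics = {"step": self._step, "batch_mean": mean}
        # First element: whether to continue; second: metrics for this step.
        return self._step < self._total_steps, metrics
```

Usage: `ToyEngine(total_steps=2).train_step([1.0, 2.0, 3.0])` returns `(True, {"step": 1, "batch_mean": 2.0})`. Instantiating the abstract base directly raises `TypeError`, which is how `ABC` enforces that subclasses provide both members.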
- trinity.trainer.get_trainer_wrapper(config: Config) → TrainEngineWrapper
Return a training engine wrapper for the given configuration.