trinity.trainer
Subpackages
- trinity.trainer.verl
  - Submodules
    - trinity.trainer.verl.core_algos module
      - KLController
      - AdaptiveKLController
      - FixedKLController
      - get_kl_controller()
      - compute_opmd_outcome_advantage()
      - compute_gae_advantage_return()
      - compute_grpo_outcome_advantage()
      - compute_rloo_outcome_advantage()
      - compute_reinforce_plus_plus_outcome_advantage()
      - compute_remax_outcome_advantage()
      - compute_rewards()
      - compute_policy_loss()
      - compute_policy_loss_dpo()
      - compute_policy_loss_pairwise_opmd()
      - compute_policy_loss_opmd()
      - compute_policy_loss_ppo()
      - compute_policy_loss_sft()
      - compute_entropy_loss()
      - compute_value_loss()
      - kl_penalty()
    - trinity.trainer.verl.dp_actor module
    - trinity.trainer.verl.fsdp_workers module
      - create_device_mesh()
      - get_sharding_strategy()
      - ActorRolloutRefWorker
        - ActorRolloutRefWorker.__init__()
        - ActorRolloutRefWorker.init_model()
        - ActorRolloutRefWorker.setup_weight_sync_group()
        - ActorRolloutRefWorker.sync_weight()
        - ActorRolloutRefWorker.set_mode()
        - ActorRolloutRefWorker.update_actor()
        - ActorRolloutRefWorker.generate_sequences()
        - ActorRolloutRefWorker.compute_log_prob()
        - ActorRolloutRefWorker.compute_ref_log_prob()
        - ActorRolloutRefWorker.save_checkpoint()
        - ActorRolloutRefWorker.load_checkpoint()
        - ActorRolloutRefWorker.clear_optimizer_state()
      - CriticWorker
      - RewardModelWorker
    - trinity.trainer.verl.ray_trainer module
  - Module contents
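The core_algos entries KLController, FixedKLController, AdaptiveKLController and get_kl_controller() point to the usual PPO KL-penalty schedule. The following is a minimal sketch of that technique, assuming the classic adaptive KL rule used in PPO-based RLHF code: the proportional error KL/target - 1 is clipped to [-0.2, 0.2], and the coefficient is multiplied by 1 + error * n_steps / horizon. The constructor arguments, the update() method and the string-keyed factory below are illustrative assumptions, not the documented signatures of trinity.trainer.verl.core_algos.

```python
from abc import ABC, abstractmethod
from typing import Optional


class KLController(ABC):
    """Holds the current KL penalty coefficient in `value` (sketch)."""

    def __init__(self, kl_coef: float) -> None:
        self.value = kl_coef

    @abstractmethod
    def update(self, current_kl: float, n_steps: int) -> None:
        """Adjust `value` after observing the measured KL."""


class FixedKLController(KLController):
    def update(self, current_kl: float, n_steps: int) -> None:
        # A fixed controller never changes its coefficient.
        pass


class AdaptiveKLController(KLController):
    def __init__(self, init_kl_coef: float, target_kl: float, horizon: int) -> None:
        super().__init__(init_kl_coef)
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        # Proportional error, clipped so a single update changes the coefficient gently.
        error = min(max(current_kl / self.target - 1.0, -0.2), 0.2)
        # KL above target -> larger penalty; KL below target -> smaller penalty.
        self.value *= 1.0 + error * n_steps / self.horizon


def get_kl_controller(kl_type: str, kl_coef: float,
                      target_kl: Optional[float] = None,
                      horizon: Optional[int] = None) -> KLController:
    """Hypothetical factory keyed on a string ("fixed" or "adaptive")."""
    if kl_type == "fixed":
        return FixedKLController(kl_coef)
    if kl_type == "adaptive":
        if target_kl is None or horizon is None:
            raise ValueError("adaptive control needs target_kl and horizon")
        return AdaptiveKLController(kl_coef, target_kl, horizon)
    raise ValueError(f"unknown kl_type: {kl_type!r}")
```

For example, with an initial coefficient of 0.1, a target KL of 0.05 and a horizon of 10000, observing a KL of 0.2 over 1000 steps clips the error to 0.2, so the coefficient becomes 0.1 * 1.02 = 0.102.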
Submodules
trinity.trainer.trainer module
Trainer class. This file is modified from verl.trainer.main_ppo.py and is a reproduction of Jiayi-Pan/TinyZero.
Note that main is kept separate from ray_trainer because ray_trainer is also used by other main entry points.
- class trinity.trainer.trainer.TrainEngineWrapper
Bases: ABC
A wrapper class for various training engines.
- abstract set_mode(algo_type: AlgorithmType) → None
Set training mode.
- trinity.trainer.trainer.get_trainer_wrapper(config: Config) → TrainEngineWrapper
Get a trainer wrapper for the given config.
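TrainEngineWrapper only declares an abstract set_mode(), so every concrete engine must implement it, and get_trainer_wrapper() is the factory that returns one for a given config. A minimal sketch of that pattern follows, assuming a stand-in AlgorithmType enum, a plain dict in place of Config, and a hypothetical `trainer_type` key and EchoEngineWrapper class; none of these come from the documentation above, which also does not say which concrete wrappers the real factory returns.

```python
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional


class AlgorithmType(Enum):
    """Hypothetical stand-in for trinity's AlgorithmType."""
    PPO = "ppo"
    SFT = "sft"
    DPO = "dpo"


class TrainEngineWrapper(ABC):
    """A wrapper class for various training engines."""

    @abstractmethod
    def set_mode(self, algo_type: AlgorithmType) -> None:
        """Set training mode."""


class EchoEngineWrapper(TrainEngineWrapper):
    """Hypothetical engine that only records the requested mode."""

    def __init__(self) -> None:
        self.mode: Optional[AlgorithmType] = None

    def set_mode(self, algo_type: AlgorithmType) -> None:
        self.mode = algo_type


def get_trainer_wrapper(config: Dict[str, Any]) -> TrainEngineWrapper:
    """Select a concrete wrapper from a hypothetical `trainer_type` key."""
    trainer_type = config.get("trainer_type", "echo")
    if trainer_type == "echo":
        return EchoEngineWrapper()
    raise NotImplementedError(f"unsupported trainer_type: {trainer_type!r}")
```

Because set_mode() is abstract, instantiating TrainEngineWrapper directly raises TypeError; only subclasses that implement it can be returned by the factory.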
trinity.trainer.verl_trainer module
veRL Trainer Class
Modified from verl/trainer/ppo/ray_trainer.py
- class trinity.trainer.verl_trainer.VerlPPOTrainerWrapper(global_config: Config)
Bases: RayPPOTrainer, TrainEngineWrapper
A wrapper for verl.trainer.ppo.RayPPOTrainer.
- train_dpo_step(experiences: Experiences) → Tuple[bool, int]
Run one training step on DPO data.
- train_sft_step(experiences: Experiences) → Tuple[bool, int]
Run one training step on SFT data.
- train_rft_step(experiences: Experiences) → Tuple[bool, int]
Run one training step on RFT data.
- set_mode(algorithm_type: AlgorithmType = AlgorithmType.PPO) → None
Set the training mode (defaults to AlgorithmType.PPO).
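VerlPPOTrainerWrapper exposes one step method per data type plus set_mode(), whose default is AlgorithmType.PPO. One plausible way to tie them together is to remember the mode and route each batch to the matching step method. The sketch below assumes that routing (SFT to train_sft_step, DPO to train_dpo_step, everything else to train_rft_step), uses a list in place of Experiences, and treats the returned Tuple[bool, int] as a hypothetical (keep-training flag, step counter) pair; the documentation above specifies none of this, and ModeDispatchingTrainer and train_step are hypothetical names.

```python
from enum import Enum
from typing import Any, List, Tuple


class AlgorithmType(Enum):
    """Hypothetical stand-in for trinity's AlgorithmType."""
    PPO = "ppo"
    SFT = "sft"
    DPO = "dpo"


class ModeDispatchingTrainer:
    """Hypothetical trainer that routes batches by the mode chosen in set_mode."""

    def __init__(self) -> None:
        self.algorithm_type = AlgorithmType.PPO
        self.global_steps = 0
        self.last_step = ""  # name of the step method used most recently

    def set_mode(self, algorithm_type: AlgorithmType = AlgorithmType.PPO) -> None:
        self.algorithm_type = algorithm_type

    def _finish(self, name: str) -> Tuple[bool, int]:
        # Assumed return convention: (keep training?, step counter).
        self.last_step = name
        self.global_steps += 1
        return True, self.global_steps

    def train_sft_step(self, experiences: List[Any]) -> Tuple[bool, int]:
        return self._finish("sft")

    def train_dpo_step(self, experiences: List[Any]) -> Tuple[bool, int]:
        return self._finish("dpo")

    def train_rft_step(self, experiences: List[Any]) -> Tuple[bool, int]:
        return self._finish("rft")

    def train_step(self, experiences: List[Any]) -> Tuple[bool, int]:
        # Dispatch on the mode most recently passed to set_mode().
        if self.algorithm_type is AlgorithmType.SFT:
            return self.train_sft_step(experiences)
        if self.algorithm_type is AlgorithmType.DPO:
            return self.train_dpo_step(experiences)
        return self.train_rft_step(experiences)
```

Keeping the mode as state means callers switch behavior once via set_mode() instead of choosing a step method on every batch.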
Module contents
- class trinity.trainer.TrainEngineWrapper
Bases: ABC
A wrapper class for various training engines.
- abstract set_mode(algo_type: AlgorithmType) → None
Set training mode.
- trinity.trainer.get_trainer_wrapper(config: Config) → TrainEngineWrapper
Get a trainer wrapper for the given config.