trinity.trainer.verl_trainer module#
veRL Trainer Class
Modified from verl/trainer/ppo/ray_trainer.py
- class trinity.trainer.verl_trainer.CheckpointMonitor(save_strategy: SaveStrategy, default_local_dir: str, default_hdfs_dir: str = None)[source]#
Bases: object
- __init__(save_strategy: SaveStrategy, default_local_dir: str, default_hdfs_dir: str = None)[source]#
- async register_thread_count(step: int, *, state_dict_thread_count: int = 0, checkpoint_thread_count: int = 0)[source]#
- classmethod get_actor(namespace: str, save_strategy: SaveStrategy | None = None, default_local_dir: str | None = None, default_hdfs_dir: str | None = None)[source]#
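A minimal usage sketch follows. It assumes get_actor implements a get-or-create pattern that returns a named Ray actor handle scoped to namespace, so the monitor's async methods are invoked via .remote(); the namespace and directory below are hypothetical.
```python
import ray

from trinity.trainer.verl_trainer import CheckpointMonitor

ray.init()

# Hypothetical: get (or create) the monitor actor in a Ray namespace.
monitor = CheckpointMonitor.get_actor(
    namespace="trinity-demo",              # hypothetical namespace
    default_local_dir="/tmp/checkpoints",  # hypothetical checkpoint dir
)

# Report how many background save threads step 10 spawned, so the monitor
# can track in-flight state-dict and checkpoint writes for that step.
ray.get(monitor.register_thread_count.remote(
    10, state_dict_thread_count=2, checkpoint_thread_count=1
))
```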
- class trinity.trainer.verl_trainer.VerlPPOTrainerWrapper(global_config: Config)[source]#
Bases: RayPPOTrainer, TrainEngineWrapper
A wrapper for verl.trainer.ppo.RayPPOTrainer.
- __init__(global_config: Config)[source]#
Initialize distributed PPO trainer with Ray backend. Note that this trainer runs on the driver process on a single CPU/GPU node.
- Parameters:
config -- Configuration object containing training parameters.
tokenizer -- Tokenizer used for encoding and decoding text.
role_worker_mapping (dict[Role, WorkerType]) -- Mapping from roles to worker classes.
resource_pool_manager (ResourcePoolManager) -- Manager for Ray resource pools.
ray_worker_group_cls (RayWorkerGroup, optional) -- Class for Ray worker groups. Defaults to RayWorkerGroup.
processor -- Optional data processor, used for multimodal data.
reward_fn -- Function for computing rewards during training.
val_reward_fn -- Function for computing rewards during validation.
train_dataset (Optional[Dataset], optional) -- Training dataset. Defaults to None.
val_dataset (Optional[Dataset], optional) -- Validation dataset. Defaults to None.
collate_fn -- Function to collate data samples into batches.
train_sampler (Optional[Sampler], optional) -- Sampler for the training dataset. Defaults to None.
device_name (str, optional) -- Device name for training (e.g., "cuda", "cpu"). Defaults to None.
- init_workers()[source]#
Initialize distributed training workers using Ray backend.
Creates:
- Ray resource pools from configuration
- Worker groups for each role (actor, critic, etc.)
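A minimal driver-side sketch of the construction and worker-initialization flow, assuming a populated Config object; the load_config helper and YAML path are hypothetical stand-ins for however the experiment configuration is actually built.
```python
from trinity.common.config import load_config  # hypothetical config loader
from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper

config = load_config("examples/ppo_demo.yaml")  # hypothetical experiment file

# Construct the wrapper on the driver process, then spin up the Ray
# resource pools and per-role worker groups described above.
trainer = VerlPPOTrainerWrapper(global_config=config)
trainer.init_workers()
```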
- property train_step_num: int#
Get the current training step number.
- async train_step(batch_exps: List[Experience]) → Dict[source]#
Train one step.
- Parameters:
batch_exps (List[Experience]) -- A batch of experiences to train on.
- Returns:
Metrics of the training step.
- Return type:
Dict
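Because train_step is a coroutine, the driver awaits it inside an event loop. A sketch, assuming trainer is the initialized wrapper from the sketch above and that Experience is importable from trinity.common.experience (an assumed module path):
```python
import asyncio
from typing import Dict, List

from trinity.common.experience import Experience  # assumed module path

async def run_steps(trainer, batches: List[List[Experience]]) -> None:
    for batch in batches:
        metrics: Dict = await trainer.train_step(batch)  # one training step
        print(f"step {trainer.train_step_num}: {metrics}")

# asyncio.run(run_steps(trainer, batches))
```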