trinity.trainer.verl_trainer module

veRL Trainer Class

Modified from verl/trainer/ppo/ray_trainer.py

class trinity.trainer.verl_trainer.CheckpointMonitor(save_strategy: SaveStrategy, default_local_dir: str, default_hdfs_dir: str = None)

Bases: object

__init__(save_strategy: SaveStrategy, default_local_dir: str, default_hdfs_dir: str = None)
update_latest_checkpoint_step(step: int)
update_latest_state_dict_step(step: int)
async register_thread_count(step: int, *, state_dict_thread_count: int = 0, checkpoint_thread_count: int = 0)
async monitor_step(step: int, is_state_dict: bool = False)
async notify_started(node_id: str, job_id: str)
async notify_finished(step: int, is_state_dict: bool = False)
classmethod get_actor(namespace: str, save_strategy: SaveStrategy | None = None, default_local_dir: str | None = None, default_hdfs_dir: str | None = None)
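The method names of CheckpointMonitor suggest a per-step bookkeeping pattern: save threads are registered for a step, each reports back when it finishes, and a waiter advances the "latest saved step" once every thread has reported. The following is a minimal asyncio sketch of that pattern only, under that assumption. The class name MiniCheckpointMonitor and its counting logic are hypothetical and do not reproduce the real implementation, which (judging from get_actor and its namespace argument) appears to run as a Ray actor.

```python
import asyncio
from collections import defaultdict


class MiniCheckpointMonitor:
    """Hypothetical sketch: counts outstanding save threads per (step, kind)."""

    def __init__(self) -> None:
        # (step, is_state_dict) -> number of save threads still running
        self._pending: dict[tuple[int, bool], int] = defaultdict(int)
        self._cond = asyncio.Condition()
        self.latest_checkpoint_step = 0
        self.latest_state_dict_step = 0

    async def register_thread_count(
        self, step: int, *, state_dict_thread_count: int = 0, checkpoint_thread_count: int = 0
    ) -> None:
        async with self._cond:
            self._pending[(step, True)] += state_dict_thread_count
            self._pending[(step, False)] += checkpoint_thread_count

    async def notify_finished(self, step: int, is_state_dict: bool = False) -> None:
        # One save thread for this step has completed; wake any waiters.
        async with self._cond:
            self._pending[(step, is_state_dict)] -= 1
            self._cond.notify_all()

    async def monitor_step(self, step: int, is_state_dict: bool = False) -> None:
        # Block until all registered threads for this step have finished,
        # then advance the matching "latest step" marker.
        async with self._cond:
            await self._cond.wait_for(lambda: self._pending[(step, is_state_dict)] <= 0)
            if is_state_dict:
                self.latest_state_dict_step = max(self.latest_state_dict_step, step)
            else:
                self.latest_checkpoint_step = max(self.latest_checkpoint_step, step)


async def demo() -> MiniCheckpointMonitor:
    monitor = MiniCheckpointMonitor()
    await monitor.register_thread_count(10, checkpoint_thread_count=2)
    waiter = asyncio.create_task(monitor.monitor_step(10))
    await monitor.notify_finished(10)
    await monitor.notify_finished(10)
    await waiter
    return monitor


monitor = asyncio.run(demo())
print(monitor.latest_checkpoint_step)  # 10
```

Using a single asyncio.Condition keeps the counter update and the wake-up atomic with respect to each other, so a waiter cannot miss the final notify_finished call.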
class trinity.trainer.verl_trainer.VerlPPOTrainerWrapper(global_config: Config)

Bases: RayPPOTrainer, TrainEngineWrapper

A wrapper for verl.trainer.ppo.RayPPOTrainer.

__init__(global_config: Config)

Initialize the distributed PPO trainer with the Ray backend. Note that this trainer runs in the driver process on a single CPU/GPU node.

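Parameters (as documented for the parent RayPPOTrainer constructor; this wrapper's own signature accepts only global_config):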
  • config -- Configuration object containing training parameters.

  • tokenizer -- Tokenizer used for encoding and decoding text.

  • role_worker_mapping (dict[Role, WorkerType]) -- Mapping from roles to worker classes.

  • resource_pool_manager (ResourcePoolManager) -- Manager for Ray resource pools.

  • ray_worker_group_cls (RayWorkerGroup, optional) -- Class for Ray worker groups. Defaults to RayWorkerGroup.

  • processor -- Optional data processor, used for multimodal data.

  • reward_fn -- Function for computing rewards during training.

  • val_reward_fn -- Function for computing rewards during validation.

  • train_dataset (Optional[Dataset], optional) -- Training dataset. Defaults to None.

  • val_dataset (Optional[Dataset], optional) -- Validation dataset. Defaults to None.

  • collate_fn -- Function to collate data samples into batches.

  • train_sampler (Optional[Sampler], optional) -- Sampler for the training dataset. Defaults to None.

  • device_name (str, optional) -- Device name for training (e.g., "cuda", "cpu"). Defaults to None.
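To illustrate what the collate_fn parameter is expected to do, here is a minimal pure-Python collate function that right-pads token-ID sequences to the longest sample in the batch and builds a matching attention mask. The name pad_collate, the "input_ids" key, right-side padding, and pad_id=0 are all assumptions for this sketch; verl's own collate function works on tensors and may pad differently.

```python
from typing import Dict, List


def pad_collate(samples: List[Dict[str, List[int]]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Hypothetical collate_fn: right-pad token sequences to the batch max length."""
    max_len = max(len(s["input_ids"]) for s in samples)
    input_ids, attention_mask = [], []
    for s in samples:
        ids = s["input_ids"]
        pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * pad)
        # 1 marks a real token, 0 marks padding
        attention_mask.append([1] * len(ids) + [0] * pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_collate([{"input_ids": [5, 6, 7]}, {"input_ids": [8]}])
print(batch["input_ids"])       # [[5, 6, 7], [8, 0, 0]]
print(batch["attention_mask"])  # [[1, 1, 1], [1, 0, 0]]
```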

init_workers()

Initialize distributed training workers using Ray backend.

Creates:

  1. Ray resource pools from configuration

  2. Worker groups for each role (actor, critic, etc.)
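The two steps of init_workers can be sketched in plain Python without Ray: resource pools are described by a spec, and roles are grouped by the pool they are placed on, so each pool yields one worker group. This assumes a spec of the form {pool_name: [gpus_per_node, ...]}; the Role enum and the helper plan_worker_groups below are hypothetical stand-ins, and no Ray calls are made.

```python
from enum import Enum
from typing import Dict, List


class Role(Enum):  # hypothetical stand-in for verl's role enum
    ActorRollout = "actor_rollout"
    Critic = "critic"
    RefPolicy = "ref"


def plan_worker_groups(
    pool_spec: Dict[str, List[int]],
    role_to_pool: Dict[Role, str],
) -> Dict[str, List[Role]]:
    """Step 1: check roles against the declared pools; step 2: group roles per pool."""
    groups: Dict[str, List[Role]] = {name: [] for name in pool_spec}
    for role, pool_name in role_to_pool.items():
        if pool_name not in pool_spec:
            raise KeyError(f"role {role.name} references unknown pool {pool_name!r}")
        groups[pool_name].append(role)
    # Keep only pools that at least one role actually uses.
    return {name: roles for name, roles in groups.items() if roles}


spec = {"global_pool": [8, 8]}  # two nodes with 8 GPUs each
plan = plan_worker_groups(spec, {Role.ActorRollout: "global_pool", Role.Critic: "global_pool"})
print({name: [r.name for r in roles] for name, roles in plan.items()})
# {'global_pool': ['ActorRollout', 'Critic']}
```

Grouping roles by pool before creating anything lets colocated roles share one set of GPUs, and an unknown pool name fails early instead of at actor-creation time.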

property train_step_num: int

Get the current training step number.

async prepare()

Perform preparation before training starts.

save_state_dict()

Save only the model state dict, for use by the Synchronizer.

upload_state_dict()

Upload the state dict to the Synchronizer.

async train_step(batch_exps: List[Experience]) → Dict

Run one training step.

Parameters:

batch_exps (List[Experience]) -- A batch of experiences to train on.

Returns:

Metrics of the training step.

Return type:

Dict
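The sketch below shows how a caller might drive the asynchronous train_step interface and read train_step_num afterwards. StubTrainer and the minimal Experience dataclass (with tokens and reward fields) are hypothetical stand-ins that mimic only the documented surface; the real trainer performs a PPO update rather than averaging rewards.

```python
import asyncio
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Experience:  # hypothetical minimal stand-in for trinity's Experience
    tokens: List[int]
    reward: float


class StubTrainer:
    """Mimics only async train_step() and the train_step_num property."""

    def __init__(self) -> None:
        self._step = 0

    @property
    def train_step_num(self) -> int:
        return self._step

    async def train_step(self, batch_exps: List[Experience]) -> Dict:
        # A real trainer would run a PPO update here; the stub just reports a metric.
        self._step += 1
        mean_reward = sum(e.reward for e in batch_exps) / len(batch_exps)
        return {"step": self._step, "reward/mean": mean_reward}


async def run(trainer: StubTrainer, batches: List[List[Experience]]) -> List[Dict]:
    # Await each step in order and collect the returned metrics.
    return [await trainer.train_step(batch) for batch in batches]


trainer = StubTrainer()
batches = [[Experience([1, 2], 1.0), Experience([3], 0.0)], [Experience([4], 0.5)]]
history = asyncio.run(run(trainer, batches))
print(history)  # [{'step': 1, 'reward/mean': 0.5}, {'step': 2, 'reward/mean': 0.5}]
print(trainer.train_step_num)  # 2
```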

save_checkpoint(block_until_saved: bool = False, save_as_hf: bool = False) → None

Save the checkpoint.

sync_weight() → None

Synchronize the model weights.

post_process_batch(batch: DataProto) → DataProto

Adapted from verl/utils/dataset/rl_dataset.py