trinity.trainer.verl

Submodules

trinity.trainer.verl.dp_actor module

Single Process Actor. Modified from https://github.com/volcengine/verl/blob/v0.4.1/verl/workers/actor/dp_actor.py

class trinity.trainer.verl.dp_actor.DataParallelPPOActor(config, actor_module: Module, actor_optimizer: Optimizer | None = None)[source]

Bases: DataParallelPPOActor

__init__(config, actor_module: Module, actor_optimizer: Optimizer | None = None)[source]

When actor_optimizer is None, the actor serves as a reference policy.

set_algorithm(algorithm_config: AlgorithmConfig)[source]
update_policy(**kwargs)

Update the policy with an iterator of DataProto

Parameters:

data (DataProto) – an iterator over DataProto objects returned by `make_minibatch_iterator`

Returns:

a dictionary of statistics collected while updating the model, such as `loss`, `grad_norm`, etc.

Return type:

Dict
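To illustrate the `update_policy` contract above, here is a minimal pure-Python sketch: iterate over minibatches, collect per-step statistics, and return one aggregated dict. The names `train_step`, `minibatches`, and the metric values are illustrative assumptions, not the actual trinity/verl API.

```python
# Hypothetical sketch of the update_policy contract: consume an
# iterator of minibatches and return a dict of training statistics.
from statistics import mean

def train_step(minibatch):
    # Stand-in for one PPO update; a real implementation would run a
    # forward/backward pass and an optimizer step on the minibatch.
    return {"loss": sum(minibatch) / len(minibatch),
            "grad_norm": max(minibatch)}

def update_policy(minibatches):
    metrics = {"loss": [], "grad_norm": []}
    for mb in minibatches:
        step_metrics = train_step(mb)
        for key, value in step_metrics.items():
            metrics[key].append(value)
    # Aggregate per-minibatch statistics into a single returned dict,
    # mirroring the "dictionary of statistics" described above.
    return {key: mean(values) for key, values in metrics.items()}
```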

trinity.trainer.verl.fsdp_checkpoint_manager module

class trinity.trainer.verl.fsdp_checkpoint_manager.FSDPCheckpointManager(*args, **kwargs)[source]

Bases: FSDPCheckpointManager

An enhanced version of the original FSDP checkpoint manager that:

  1. Uploads model state dicts to a remote Synchronizer actor (either directly or via checkpoints).

  2. Offloads saving operations (model, optimizer, extra states) into background threads to avoid blocking the training loop.

This class is useful in distributed training scenarios where synchronization and non-blocking I/O are important.

__init__(*args, **kwargs)[source]
upload_state_dict(global_step: int)[source]

Uploads the full model state dictionary to the synchronizer actor for remote access.

Parameters:

global_step (int) – The current training step number.

save_checkpoint(local_path: str, hdfs_path: str | None = None, global_step: int = 0, max_ckpt_to_keep: int | None = None, model_state_dict_only: bool = False)[source]

Modified from verl.utils.checkpoint.fsdp_checkpoint_manager.py:save_checkpoint

Saves the model checkpoint to disk, optionally uploads it to a remote Synchronizer, and uses background threads to prevent blocking the main training loop.

Main improvements over the base class:

  • Uses separate threads for saving the model, optimizer, and extra states.

  • Implements synchronization with a remote actor. If the model has not been trained yet (global_step == 0) or training resumes from a checkpoint, the Synchronizer is notified and the model is not saved.

Parameters:
  • local_path (str) – Local directory path to save the checkpoint.

  • hdfs_path (str, optional) – HDFS path for saving the checkpoint (not implemented here).

  • global_step (int) – Current training step.

  • max_ckpt_to_keep (int, optional) – Maximum number of checkpoints to keep locally.

  • model_state_dict_only (bool) – Whether to only save the model state dict (no optimizer, etc.).

wait_on_save_thread() None[source]

Wait for all background saving threads to complete.
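The non-blocking save pattern described for this class can be sketched with plain `threading`: checkpoint writes run in background threads so the training loop is not blocked, and a `wait_on_save_thread`-style method joins them before the next save or shutdown. The `BackgroundSaver` class and `slow_save` function below are illustrative stand-ins, not the actual FSDPCheckpointManager implementation.

```python
import threading
import time

class BackgroundSaver:
    """Illustrative sketch: offload save operations to background threads."""

    def __init__(self):
        self._threads = []

    def save_async(self, save_fn, *args):
        # Launch the (slow) save operation without blocking the caller.
        t = threading.Thread(target=save_fn, args=args, daemon=True)
        t.start()
        self._threads.append(t)

    def wait_on_save_thread(self):
        # Join all outstanding save threads, then forget them.
        for t in self._threads:
            t.join()
        self._threads.clear()

saved = []

def slow_save(name):
    time.sleep(0.05)  # simulate disk I/O
    saved.append(name)

saver = BackgroundSaver()
saver.save_async(slow_save, "model")
saver.save_async(slow_save, "optimizer")
saver.wait_on_save_thread()  # blocks until both writes have finished
```

Joining before the next checkpoint keeps at most one save in flight per state component, which is the property the training loop relies on.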

trinity.trainer.verl.fsdp_workers module

The main entry point to run the PPO algorithm. Modified from https://github.com/volcengine/verl/blob/v0.4.1/verl/workers/fsdp_workers.py

class trinity.trainer.verl.fsdp_workers.ActorRolloutRefWorker(*args, **kwargs)[source]

Bases: Worker

Depending on config.rollout, this worker can be instantiated as a standalone actor, a standalone rollout worker, a standalone reference policy, or a hybrid engine.

__init__(config: DictConfig, role: str)[source]

Initialize the worker with environment settings and device configuration.

Parameters:
  • config (DictConfig) – Worker configuration.

  • role (str) – The role this worker takes on.

init_model()[source]
setup_weight_sync_group()[source]
sync_weight()[source]
upload_state_dict(trainer_step: int)[source]
set_algorithm(algo_config: AlgorithmConfig)[source]
update_actor(data: DataProto)[source]
compute_log_prob(data: DataProto)[source]
compute_ref_log_prob(data: DataProto)[source]
save_checkpoint(local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None, model_state_dict_only=False)[source]
load_checkpoint(local_path, hdfs_path=None, del_local_after_load=False)[source]
clear_optimizer_state()[source]
wait_on_save_thread() None[source]
class trinity.trainer.verl.fsdp_workers.CriticWorker(*args, **kwargs)[source]

Bases: Worker

__init__(config)[source]

Initialize the worker with environment settings and device configuration.

Parameters:

config – Worker configuration.

init_model()[source]
compute_values(data: DataProto)[source]
update_critic(data: DataProto)[source]
save_checkpoint(local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None)[source]
load_checkpoint(local_path, hdfs_path=None, del_local_after_load=True)[source]
clear_optimizer_state()[source]
wait_on_save_thread() None[source]

trinity.trainer.verl.utils module

Utilities for compatibility with verl.

trinity.trainer.verl.utils.to_data_proto(experiences: Experiences) DataProto[source]

Convert Experiences to verl DataProto.
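A hedged sketch of what an Experiences-to-batch conversion can look like: column-wise fields are validated for equal length and regrouped into a single batch dict, analogous in spirit to building a verl DataProto. The `Experiences` fields and `to_batch_dict` helper below are assumptions for illustration, not the real schema.

```python
from dataclasses import dataclass, field

@dataclass
class Experiences:
    # Hypothetical fields; the actual Experiences structure differs.
    tokens: list = field(default_factory=list)
    rewards: list = field(default_factory=list)

def to_batch_dict(exps: Experiences) -> dict:
    # A DataProto-style batch requires all fields to share a batch size.
    if len(exps.tokens) != len(exps.rewards):
        raise ValueError("all fields must have the same batch size")
    return {
        "tokens": list(exps.tokens),
        "rewards": list(exps.rewards),
        "batch_size": len(exps.tokens),
    }
```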

trinity.trainer.verl.utils.compute_data_metrics(batch: DataProto, use_critic: bool = False) dict[source]

Computes various metrics from a batch of data for PPO training. Modified from verl.trainer.ppo.metric_utils.compute_data_metrics

This function calculates metrics related to scores, rewards, advantages, returns, values, and sequence lengths from a batch of data. It provides statistical information (mean, max, min) for each metric category.

Parameters:
  • batch – A DataProto object containing batch data with token-level scores, rewards, advantages, etc.

  • use_critic – Whether to include critic-specific metrics. Defaults to False.

Returns:

A dictionary of metrics including:
  • critic/score/mean, max, min: Statistics about sequence scores

  • critic/rewards/mean, max, min: Statistics about sequence rewards

  • critic/advantages/mean, max, min: Statistics about advantages

  • critic/returns/mean, max, min: Statistics about returns

  • critic/values/mean, max, min: Statistics about critic values (if use_critic=True)

  • critic/vf_explained_var: Explained variance of the value function (if use_critic=True)

  • response_length/mean, max, min, clip_ratio: Statistics about response lengths

  • prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths

Return type:

dict
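The mean/max/min and clip_ratio statistics listed above can be sketched in pure Python. The `summarize` and `length_metrics` helpers below are illustrative stand-ins; the real `compute_data_metrics` operates on DataProto tensors, and `max_length` here is an assumed parameter for the truncation limit.

```python
def summarize(name, values):
    # Per-quantity mean/max/min, keyed in the "name/stat" style above.
    return {
        f"{name}/mean": sum(values) / len(values),
        f"{name}/max": max(values),
        f"{name}/min": min(values),
    }

def length_metrics(lengths, max_length):
    stats = summarize("response_length", lengths)
    # clip_ratio: fraction of responses that hit the maximum length,
    # i.e. were (potentially) truncated.
    stats["response_length/clip_ratio"] = (
        sum(1 for n in lengths if n >= max_length) / len(lengths)
    )
    return stats
```

A high clip_ratio signals that many responses are being cut off at the length limit, which usually means the limit should be raised or the reward inspected.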

Module contents