Source code for trinity.common.verl_config

import math
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from omegaconf import OmegaConf

from trinity.common.config import BufferConfig, Config, SynchronizerConfig
from trinity.common.constants import AlgorithmType
from trinity.trainer.verl.ray_trainer import AdvantageEstimator
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@dataclass
class Data:
    train_batch_size: int = 1024

@dataclass
class ActorModel:
    path: str = ""
    external_lib: Optional[str] = None
    override_config: Dict[str, Any] = field(default_factory=dict)
    enable_gradient_checkpointing: bool = True
    use_remove_padding: bool = False

@dataclass
class Optim:
    lr: float = 1e-6
    lr_warmup_steps: int = -1
    lr_warmup_steps_ratio: float = 0.0
    min_lr_ratio: Optional[float] = 0.0
    warmup_style: str = "constant"
    total_training_steps: int = -1
    beta1: float = 0.9
    beta2: float = 0.999

@dataclass
class WrapPolicy:
    min_num_params: int = 0

@dataclass
class FSDPConfig:
    wrap_policy: WrapPolicy = field(default_factory=WrapPolicy)
    min_num_params: int = 0
    param_offload: bool = False
    optimizer_offload: bool = False
    fsdp_size: int = -1

@dataclass
class Checkpoint:
    contents: List[str] = field(
        default_factory=lambda: ["model", "hf_model", "optimizer", "extra"]
    )

@dataclass
class Actor:
    strategy: str = "fsdp"
    ppo_mini_batch_size: int = 256
    ppo_micro_batch_size: Optional[int] = None
    ppo_micro_batch_size_per_gpu: int = 1
    use_dynamic_bsz: bool = False
    ppo_max_token_len_per_gpu: int = (
        16384  # n * ${data.max_prompt_length} + ${data.max_response_length}
    )
    grad_clip: float = 1.0
    clip_ratio: float = 0.2
    entropy_coeff: float = 0.001
    use_kl_loss: bool = False
    kl_loss_coef: float = 0.001
    kl_loss_type: str = "low_var_kl"
    ppo_epochs: int = 1
    shuffle: bool = False
    ulysses_sequence_parallel_size: int = 1
    checkpoint: Checkpoint = field(default_factory=Checkpoint)
    optim: Optim = field(default_factory=Optim)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    algorithm_type: AlgorithmType = AlgorithmType.PPO
    tau: float = 0.001  # strength of regularization w.r.t. old / ref policy
    opmd_baseline: str = "mean"  # mean / logavgexp, applicable to opmd
    use_uid: bool = False  # True / False, applicable to pairwise_opmd

@dataclass
class Ref:
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    log_prob_micro_batch_size: Optional[int] = None
    log_prob_micro_batch_size_per_gpu: int = 1
    log_prob_use_dynamic_bsz: bool = False
    log_prob_max_token_len_per_gpu: int = 0
    ulysses_sequence_parallel_size: int = 1

@dataclass
class Rollout:
    temperature: float = 1.0
    n: int = 1  # > 1 for grpo

@dataclass
class ActorRolloutRef:
    hybrid_engine: bool = True
    model: ActorModel = field(default_factory=ActorModel)
    actor: Actor = field(default_factory=Actor)
    ref: Ref = field(default_factory=Ref)
    rollout: Rollout = field(default_factory=Rollout)
    synchronizer: Optional[SynchronizerConfig] = None

@dataclass
class CriticModel:
    path: str = ""
    tokenizer_path: str = ""
    override_config: Dict[str, str] = field(default_factory=dict)
    external_lib: Optional[str] = None
    enable_gradient_checkpointing: bool = True
    use_remove_padding: bool = False
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)

@dataclass
class Critic:
    strategy: str = "fsdp"
    optim: Optim = field(default_factory=Optim)
    model: CriticModel = field(default_factory=CriticModel)
    ppo_mini_batch_size: int = 0
    ppo_micro_batch_size: Optional[int] = None
    ppo_micro_batch_size_per_gpu: int = 1
    forward_micro_batch_size: Optional[int] = None
    forward_micro_batch_size_per_gpu: Optional[int] = None
    use_dynamic_bsz: bool = False
    ppo_max_token_len_per_gpu: int = 0
    forward_max_token_len_per_gpu: int = 0
    ulysses_sequence_parallel_size: int = 1
    ppo_epochs: int = 0
    shuffle: bool = False
    grad_clip: float = 0.0
    cliprange_value: float = 0.0
    checkpoint: Checkpoint = field(default_factory=Checkpoint)
    rollout_n: int = 1

@dataclass
class _RewardModel:
    input_tokenizer: Optional[str] = None
    path: str = ""
    external_lib: Optional[str] = None
    use_remove_padding: bool = False
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)

@dataclass
class RewardModel:
    enable: bool = False
    strategy: str = "fsdp"
    model: _RewardModel = field(default_factory=_RewardModel)
    micro_batch_size_per_gpu: int = 1
    max_length: Optional[int] = None
    ulysses_sequence_parallel_size: int = 1
    use_dynamic_bsz: bool = False
    forward_max_token_len_per_gpu: int = 0
    reward_manager: str = "naive"

@dataclass
class CustomRewardFunction:
    path: Optional[str] = None
    name: str = "compute_score"

@dataclass
class KL_Ctrl:
    type: str = "fixed"
    kl_coef: float = 0.001
    horizon: float = 10000
    target_kl: float = 0.1

@dataclass
class Algorithm:
    gamma: float = 1.0
    lam: float = 1.0
    adv_estimator: str = "gae"
    norm_adv_by_std_in_grpo: bool = True
    use_kl_in_reward: bool = False
    kl_penalty: str = "kl"
    kl_ctrl: KL_Ctrl = field(default_factory=KL_Ctrl)

@dataclass
class Trainer:
    balance_batch: bool = True
    total_epochs: int = 30
    total_training_steps: Optional[int] = None
    project_name: str = ""
    experiment_name: str = ""
    logger: List[str] = field(default_factory=list)
    val_generations_to_log_to_wandb: int = 0
    nnodes: int = 0
    n_gpus_per_node: int = 0
    save_freq: int = 0
    resume_mode: str = "auto"
    resume_from_path: str = ""
    test_freq: int = 0
    critic_warmup: int = 0
    default_hdfs_dir: Optional[str] = None
    remove_previous_ckpt_in_save: bool = False  # deprecated
    del_local_ckpt_after_load: bool = False
    default_local_dir: str = ""
    val_before_train: bool = False
    training_rollout_mode: str = "parallel"
    enable_exp_buffer: bool = True
    sync_freq: int = 0
    sft_warmup_steps: int = 0
    max_actor_ckpt_to_keep: Optional[int] = None
    max_critic_ckpt_to_keep: Optional[int] = None

@dataclass
class veRLConfig:
    data: Data = field(default_factory=Data)
    actor_rollout_ref: ActorRolloutRef = field(default_factory=ActorRolloutRef)
    critic: Critic = field(default_factory=Critic)
    reward_model: RewardModel = field(default_factory=RewardModel)
    custom_reward_function: CustomRewardFunction = field(default_factory=CustomRewardFunction)
    algorithm: Algorithm = field(default_factory=Algorithm)
    trainer: Trainer = field(default_factory=Trainer)
    buffer: BufferConfig = field(default_factory=BufferConfig)
    synchronizer: Optional[SynchronizerConfig] = None
    enable_preview: bool = True
    def synchronize_config(self, config: Config) -> None:  # noqa: C901
        """Synchronize config."""
        if config.mode != "train":
            rollout_gpu_num = (
                config.explorer.rollout_model.tensor_parallel_size
                * config.explorer.rollout_model.engine_num
                + sum(
                    [
                        model.tensor_parallel_size * model.engine_num
                        for model in config.explorer.auxiliary_models
                    ]
                )
            )
        else:
            rollout_gpu_num = 0
        if config.cluster.node_num == 1:
            # for single node scenarios, rollout and training are on the same node
            self.trainer.nnodes = config.cluster.node_num
            self.trainer.n_gpus_per_node = config.cluster.gpu_per_node - rollout_gpu_num
        else:
            # for multi-node scenarios, some nodes for rollout, others for training
            assert (
                rollout_gpu_num % config.cluster.gpu_per_node == 0
            ), "rollout_gpu_num must be divisible by `gpu_per_node`"
            rollout_node_num = math.ceil(rollout_gpu_num / config.cluster.gpu_per_node)
            self.trainer.nnodes = config.cluster.node_num - rollout_node_num
            if self.trainer.nnodes < 1:
                raise ValueError("The number of training nodes must be greater than 0")
            self.trainer.n_gpus_per_node = config.cluster.gpu_per_node
        world_size = self.trainer.nnodes * self.trainer.n_gpus_per_node
        if config.buffer.batch_size % world_size != 0:
            raise ValueError(
                f"batch_size ({config.buffer.batch_size}) must be divisible by ({world_size})"
            )
        self.trainer.sync_freq = config.synchronizer.sync_interval
        self.trainer.save_freq = config.trainer.save_interval
        self.trainer.project_name = config.project
        self.trainer.experiment_name = config.name
        self.trainer.default_local_dir = config.checkpoint_job_dir
        self.trainer.sft_warmup_steps = config.buffer.trainer_input.sft_warmup_steps
        self.buffer = config.buffer
        # TODO: use dynamic read_batch_size to support multi-round scenarios
        # Get the experiences of one explore step
        self.data.train_batch_size = config.buffer.batch_size
        self.synchronizer = config.synchronizer
        self.actor_rollout_ref.synchronizer = config.synchronizer

        # Actor / Critic config
        self.actor_rollout_ref.model.path = config.model.model_path
        self.critic.model.path = config.model.critic_model_path
        self.critic.model.tokenizer_path = config.model.critic_model_path
        self.actor_rollout_ref.actor.ppo_mini_batch_size = (
            config.buffer.batch_size
        )  # TODO: may allow user to change
        self.actor_rollout_ref.rollout.temperature = (
            config.buffer.explorer_input.taskset.rollout_args.temperature
        )
        self.actor_rollout_ref.rollout.n = config.algorithm.repeat_times
        self.critic.ppo_mini_batch_size = config.buffer.batch_size
        self.critic.rollout_n = self.actor_rollout_ref.rollout.n
        if config.trainer.actor_use_kl_loss is not None:
            self.actor_rollout_ref.actor.use_kl_loss = config.trainer.actor_use_kl_loss
        if config.trainer.actor_kl_loss_coef is not None:
            self.actor_rollout_ref.actor.kl_loss_coef = config.trainer.actor_kl_loss_coef
        if config.trainer.actor_entropy_coef is not None:
            self.actor_rollout_ref.actor.entropy_coeff = config.trainer.actor_entropy_coef
        if config.trainer.actor_grad_clip is not None:
            self.actor_rollout_ref.actor.grad_clip = config.trainer.actor_grad_clip
        if config.trainer.actor_clip_ratio is not None:
            self.actor_rollout_ref.actor.clip_ratio = config.trainer.actor_clip_ratio

        # Algorithm related config
        if config.algorithm.gamma is not None:
            self.algorithm.gamma = config.algorithm.gamma
        if config.algorithm.lam is not None:
            self.algorithm.lam = config.algorithm.lam
        self.actor_rollout_ref.actor.algorithm_type = config.algorithm.algorithm_type
        if config.algorithm.algorithm_type == AlgorithmType.PPO:
            logger.info("Using GAE `adv_estimator` for PPO")
            self.algorithm.adv_estimator = AdvantageEstimator.GAE.value
        elif config.algorithm.algorithm_type == AlgorithmType.GRPO:
            logger.info("Using GRPO `adv_estimator` for GRPO")
            self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value
        if self.actor_rollout_ref.actor.algorithm_type.is_dpo():
            # for DPO
            if not self.actor_rollout_ref.actor.use_kl_loss:
                self.actor_rollout_ref.actor.use_kl_loss = True
                logger.warning("DPO must use KL loss.")
            logger.warning("DPO micro batch size is doubled for computing loss.")
            self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu *= 2  # type: ignore
            self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu *= 2  # type: ignore
            if self.actor_rollout_ref.rollout.n != 2:
                self.actor_rollout_ref.rollout.n = 2
        # TODO: check other fields
        self.enable_preview = config.trainer.enable_preview

def load_config(config_path: str) -> veRLConfig:
    schema = OmegaConf.structured(veRLConfig)
    yaml_config = OmegaConf.load(config_path)
    try:
        config = OmegaConf.merge(schema, yaml_config)
        return OmegaConf.to_object(config)
    except Exception as e:
        raise ValueError(f"Invalid configuration: {e}") from e
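
A minimal usage sketch, assuming a veRL YAML file at a hypothetical path ``verl.yaml`` and a Trinity ``Config`` instance ``trinity_config`` built elsewhere (both are placeholders, not defined in this module)::

    # Hypothetical example: "verl.yaml" and trinity_config are placeholders.
    verl_config = load_config("verl.yaml")          # merge the YAML over the structured schema
    verl_config.synchronize_config(trinity_config)  # map Trinity settings onto the veRL fields
    print(verl_config.trainer.nnodes, verl_config.trainer.n_gpus_per_node)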