trinity.algorithm.policy_loss_fn.mix_policy_loss module#
Mix policy loss function.
- class trinity.algorithm.policy_loss_fn.mix_policy_loss.MIXPolicyLossFn(backend: str = 'verl', mu: float = 0.1, clip_range: float | None = None, clip_range_low: float | None = None, clip_range_high: float | None = None, use_dynamic_bsz: bool | None = None, ppo_mini_batch_size: int = 1, ppo_micro_batch_size_per_gpu: int = 1, ngpus_trainer: int = 1, train_batch_size_usual: int = 1, train_batch_size_expert: int = 1, loss_agg_mode: str = 'token-mean', sft_loss_agg_mode: str | None = None, grpo_loss_agg_mode: str | None = None)[source]#
Bases: PolicyLossFn

Implements a mixed policy loss combining GRPO and SFT losses.
This loss function applies different loss components to each sample depending on whether it comes from an expert, as indicated by expert_mask. It combines:
- GRPO loss (self.grpo_loss_fn) for non-expert data
- SFT loss (self.sft_loss_fn) for expert data
- the weighting parameter mu, which balances the two terms
The per-sample weights are normalized by either experience_per_gpu or gradient_accumulation, depending on whether dynamic batch sizing is enabled, to ensure consistent weighting across different batches of the same experience type.
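The mixing described above can be sketched in plain Python. This is a minimal illustration only, not the actual implementation: the real loss operates on token-level tensors with the aggregation modes and batch-size normalization described above, and the function name and exact mixing formula here are assumptions.

```python
# Illustrative sketch: combine per-sample GRPO and SFT losses using an
# expert mask and weighting parameter mu (names/formula are assumptions).
def mix_policy_loss(grpo_losses, sft_losses, expert_mask, mu=0.1):
    """grpo_losses / sft_losses: per-sample scalar losses (same length);
    expert_mask: True where the sample comes from expert (SFT) data."""
    usual = [g for g, is_expert in zip(grpo_losses, expert_mask) if not is_expert]
    expert = [s for s, is_expert in zip(sft_losses, expert_mask) if is_expert]
    grpo_term = sum(usual) / max(len(usual), 1)   # mean over non-expert samples
    sft_term = sum(expert) / max(len(expert), 1)  # mean over expert samples
    return (1.0 - mu) * grpo_term + mu * sft_term

# Two samples: the first non-expert (GRPO branch), the second expert (SFT branch).
loss = mix_policy_loss([1.0, 3.0], [2.0, 4.0], [False, True], mu=0.5)
# -> 0.5 * 1.0 + 0.5 * 4.0 = 2.5
```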
- __init__(backend: str = 'verl', mu: float = 0.1, clip_range: float | None = None, clip_range_low: float | None = None, clip_range_high: float | None = None, use_dynamic_bsz: bool | None = None, ppo_mini_batch_size: int = 1, ppo_micro_batch_size_per_gpu: int = 1, ngpus_trainer: int = 1, train_batch_size_usual: int = 1, train_batch_size_expert: int = 1, loss_agg_mode: str = 'token-mean', sft_loss_agg_mode: str | None = None, grpo_loss_agg_mode: str | None = None) None[source]#
Initialize the policy loss function.
- Parameters:
backend -- The training framework/backend to use (e.g., "verl")
- classmethod default_args() Dict[source]#
Get default initialization arguments for this loss function.
- Returns:
The default init arguments for the policy loss function.
- Return type:
Dict
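As a usage sketch, the dict returned by default_args() can be merged with user overrides before constructing the loss function. The key names below are taken from the __init__ signature; the full contents of the returned dict are not shown here.

```python
# Hypothetical defaults dict mirroring part of the __init__ signature.
defaults = {"backend": "verl", "mu": 0.1, "loss_agg_mode": "token-mean"}
overrides = {"mu": 0.25}  # user-supplied override

# In a dict merge, later keys win, so overrides take precedence over defaults.
init_kwargs = {**defaults, **overrides}
```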
- property select_keys#
Returns parameter keys mapped to the specific training framework's naming convention.
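The idea behind such a key mapping can be illustrated as follows. Both the mapping entries and the helper are made up for illustration; the real keys and their backend-specific names are defined by the training framework.

```python
# Purely illustrative key translation: generic parameter names on the left,
# hypothetical backend-specific names on the right.
KEY_MAP = {"clip_range_low": "clip_ratio_low", "clip_range_high": "clip_ratio_high"}

def translate_keys(keys, key_map):
    # Fall back to the generic name when the backend uses the same one.
    return [key_map.get(k, k) for k in keys]

translated = translate_keys(["clip_range_low", "mu"], KEY_MAP)
# -> ["clip_ratio_low", "mu"]
```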