Source code for trinity.algorithm.policy_loss_fn.chord_policy_loss

"""Implements the CHORD policy loss function."""

import math
from typing import Dict, Optional, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
from trinity.algorithm.utils import masked_mean


def mu_schedule_function(
    global_step: int, mu_warmup_steps: int, mu_decay_steps: int, mu_peak: float, mu_valley: float
) -> float:
    """
    Computes a cosine decay schedule with a warmup phase for the mu parameter.
    """
    # Warmup
    if global_step < mu_warmup_steps:
        return (global_step / mu_warmup_steps) * mu_peak
    # Decay
    if global_step >= (mu_warmup_steps + mu_decay_steps):
        return mu_valley
    adjusted_step = global_step - mu_warmup_steps
    cosine_decay = 0.5 * (1 + math.cos(math.pi * adjusted_step / mu_decay_steps))
    decayed_mu = (mu_peak - mu_valley) * cosine_decay + mu_valley
    return decayed_mu

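# For intuition, a minimal sketch of the schedule with illustrative (non-default)
# arguments mu_warmup_steps=10, mu_decay_steps=100, mu_peak=0.9, mu_valley=0.05:
# mu rises linearly from 0 to the peak during warmup, follows a cosine decay down
# to the valley, and then stays there.
#
#     >>> [round(mu_schedule_function(s, 10, 100, 0.9, 0.05), 3) for s in (0, 5, 10, 60, 110, 200)]
#     [0.0, 0.45, 0.9, 0.475, 0.05, 0.05]
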
[docs] @POLICY_LOSS_FN.register_module("sft_is") class SFTISLossFn(PolicyLossFn): """ SFT loss with importance sampling """
[docs] def __init__(self, backend: str = "verl", use_token_level_loss: bool = True) -> None: super().__init__(backend=backend) self.use_token_level_loss = use_token_level_loss
    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        action_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        token_prob = torch.exp(logprob)
        if self.use_token_level_loss:
            sft_loss = masked_mean(-logprob * token_prob.detach(), action_mask)
        else:
            sft_loss = masked_mean(-logprob * token_prob.detach(), action_mask, axis=1).mean()
        return sft_loss, {"sft_is_loss": sft_loss.detach().item()}

[docs] @classmethod def default_args(cls): return { "use_token_level_loss": True, }
def phi_function(token_prob):
    """
    The phi function downweights tokens with extreme probabilities.
    Feel free to modify this function.
    """
    return token_prob * (1 - token_prob)

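# phi(p) = p * (1 - p) peaks at p = 0.5 and vanishes as p -> 0 or p -> 1, so tokens
# the policy is already very sure (or very unsure) about contribute little weight:
#
#     >>> [round(phi_function(p), 4) for p in (0.01, 0.5, 0.99)]
#     [0.0099, 0.25, 0.0099]
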
[docs] @POLICY_LOSS_FN.register_module("sft_phi") class SFTPhiLossFn(PolicyLossFn): """ SFT loss with transformed phi function """
[docs] def __init__( self, backend: str = "verl", use_token_level_loss: bool = True, cutoff_prob: float = 1.0 ) -> None: super().__init__(backend=backend) self.use_token_level_loss = use_token_level_loss self.cutoff_prob = cutoff_prob assert 0.0 <= self.cutoff_prob <= 1.0
    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        action_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        token_prob = torch.exp(logprob)
        if self.cutoff_prob < 1.0:
            logprob = torch.clamp(logprob, max=math.log(self.cutoff_prob))
        weighted_phi = phi_function(token_prob)
        if self.use_token_level_loss:
            sft_loss = masked_mean(-logprob * weighted_phi.detach(), action_mask)
        else:
            sft_loss = masked_mean(-logprob * weighted_phi.detach(), action_mask, axis=1).mean()
        return sft_loss, {"sft_phi_loss": sft_loss.detach().item()}

[docs] @classmethod def default_args(cls): return { "use_token_level_loss": True, "cutoff_prob": 1.0, }
[docs] @POLICY_LOSS_FN.register_module("mix_chord") class MIXCHORDPolicyLossFn(PolicyLossFn): """Implements a mixed policy loss combining GRPO and SFT losses. This loss function applies different loss components to data based on whether it comes from an expert or not, as indicated by `expert_mask`. It combines: - GRPO loss (self.grpo_loss_fn) for non-expert data - SFT loss (self.sft_loss_fn) for expert data the weight of SFT loss is globally controled by `mu_schedule` function the tokenwise weights are calculated using different SFT loss formulas The per-sample weights are normalized using either `experience_per_gpu` or `gradient_accumulation`, depending on whether dynamic batch sizing is enabled, to ensure consistent weighting across different batches of the same type experiences. """
[docs] def __init__( self, backend: str = "verl", mu_warmup_steps: int = 0, mu_decay_steps: int = 0, mu_peak: float = 0.1, mu_valley: float = 0.1, enable_phi_function: bool = True, clip_range: Optional[float] = None, clip_range_low: Optional[float] = None, clip_range_high: Optional[float] = None, use_dynamic_bsz: Optional[bool] = None, ppo_mini_batch_size: int = 1, ppo_micro_batch_size_per_gpu: int = 1, ngpus_trainer: int = 1, train_batch_size_usual: int = 1, train_batch_size_expert: int = 1, use_token_level_loss_in_sft: bool = True, ) -> None: super().__init__(backend=backend) self.mu_warmup_steps = mu_warmup_steps self.mu_decay_steps = mu_decay_steps self.mu_peak = mu_peak self.mu_valley = mu_valley self.enable_phi_function = enable_phi_function self.use_dynamic_bsz = use_dynamic_bsz self.experience_per_gpu = ppo_mini_batch_size // ngpus_trainer self.gradient_accumulation = ppo_mini_batch_size // ppo_micro_batch_size_per_gpu self.train_batch_size_usual = train_batch_size_usual // ngpus_trainer self.train_batch_size_expert = train_batch_size_expert // ngpus_trainer self.grpo_loss_fn = PPOPolicyLossFn( clip_range=clip_range, clip_range_low=clip_range_low, clip_range_high=clip_range_high, ) if enable_phi_function: self.sft_loss_fn = SFTPhiLossFn(use_token_level_loss=use_token_level_loss_in_sft) else: self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft)
    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        expert_mask: torch.Tensor,
        step: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        assert (
            len(expert_mask) == logprob.shape[0]
        ), f"Error: {len(expert_mask)=} != {logprob.shape[0]=}"
        assert len(step) == logprob.shape[0], f"Error: {len(step)=} != {logprob.shape[0]=}"
        assert (
            step.max().item() == step.min().item()
        ), f"Error: {step.max().item()} != {step.min().item()}"
        current_step = step.max().item()

        n_usual_exp = torch.sum(~expert_mask).item()
        n_expert_exp = torch.sum(expert_mask).item()

        if self.use_dynamic_bsz:
            per_micro_batch_weight_usual = self.experience_per_gpu / (
                logprob.shape[0] * self.train_batch_size_usual
            )
            per_micro_batch_weight_expert = self.experience_per_gpu / (
                logprob.shape[0] * self.train_batch_size_expert
            )
        else:
            per_micro_batch_weight_usual = self.gradient_accumulation / self.train_batch_size_usual  # type: ignore
            per_micro_batch_weight_expert = self.gradient_accumulation / self.train_batch_size_expert  # type: ignore

        # GRPO loss (usual, non-expert)
        if n_usual_exp > 0:
            grpo_loss, grpo_metrics = self.grpo_loss_fn(
                logprob[~expert_mask],
                old_logprob[~expert_mask],
                action_mask[~expert_mask],
                advantages[~expert_mask],
                **kwargs,
            )
            grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
            grpo_metrics = {
                k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items()
            }
        else:
            grpo_loss = torch.tensor(0.0, device=logprob.device)
            grpo_metrics = {}

        # SFT Loss (expert)
        if n_expert_exp > 0:
            sft_loss, sft_metrics = self.sft_loss_fn(
                logprob[expert_mask],
                action_mask[expert_mask],
            )
            sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
            sft_metrics = {
                k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items()
            }
        else:
            sft_loss = torch.tensor(0.0, device=logprob.device)
            sft_metrics = {}

        mu = mu_schedule_function(
            current_step, self.mu_warmup_steps, self.mu_decay_steps, self.mu_peak, self.mu_valley
        )
        loss = (1 - mu) * grpo_loss + mu * sft_loss
        metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()}
        metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()})
        metrics.update({"loss": loss.item(), "mu": mu})
        return loss, metrics

    @classmethod
    def default_args(cls) -> Dict:
        """
        mu_warmup_steps: int,
        mu_decay_steps: int,
        mu_peak: float,
        mu_valley: float
        """
        return {
            "mu_warmup_steps": 0,
            "mu_decay_steps": 0,
            "mu_peak": 0.1,
            "mu_valley": 0.1,
            "clip_range": 0.2,
            "enable_phi_function": True,
        }
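
# A minimal usage sketch with assumed shapes: [batch, seq_len] tensors, a boolean
# expert_mask selecting the expert samples, and a per-sample `step` tensor that
# repeats the current global step. The argument values are illustrative only:
#
#     >>> loss_fn = MIXCHORDPolicyLossFn(
#     ...     mu_warmup_steps=10, mu_decay_steps=100, mu_peak=0.5, mu_valley=0.1,
#     ...     clip_range=0.2,
#     ... )
#     >>> batch, seq_len = 4, 8
#     >>> logprob = torch.randn(batch, seq_len).clamp(max=0.0)
#     >>> loss, metrics = loss_fn(
#     ...     logprob=logprob,
#     ...     old_logprob=logprob.detach(),
#     ...     action_mask=torch.ones(batch, seq_len),
#     ...     advantages=torch.randn(batch, seq_len),
#     ...     expert_mask=torch.tensor([False, False, True, True]),
#     ...     step=torch.full((batch,), 42),
#     ... )
#
# `metrics` then contains "loss", "mu", and the per-branch entries prefixed with
# "usual/" (GRPO) and "expert/" (SFT).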