Source code for trinity.algorithm.policy_loss_fn.mix_policy_loss

"""Mix policy loss function."""

from typing import Dict, Optional, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn


@POLICY_LOSS_FN.register_module("mix")
class MIXPolicyLossFn(PolicyLossFn):
    """Implements a mixed policy loss combining GRPO and SFT losses.

    This loss function applies different loss components to the data depending on
    whether it comes from an expert, as indicated by `is_expert_mask`. It combines:

    - GRPO loss (`self.grpo_loss_fn`) for non-expert data
    - SFT loss (`self.sft_loss_fn`) for expert data
    - a weighting parameter `mu`

    The per-sample weights are normalized using either `experience_per_gpu` or
    `gradient_accumulation`, depending on whether dynamic batch sizing is enabled,
    so that experiences of the same type are weighted consistently across batches.
    A minimal usage sketch is appended after the class definition.
    """
    def __init__(
        self,
        backend: str = "verl",
        mu: float = 0.1,
        clip_range: Optional[float] = None,
        clip_range_low: Optional[float] = None,
        clip_range_high: Optional[float] = None,
        use_dynamic_bsz: Optional[bool] = None,
        repeat_times: int = 1,
        ppo_mini_batch_size: int = 1,
        ppo_micro_batch_size_per_gpu: int = 1,
        ngpus_trainer: int = 1,
        read_batch_size_usual: int = 1,
        read_batch_size_expert: int = 1,
        use_token_level_loss_in_sft: bool = True,
    ) -> None:
        super().__init__(backend=backend)
        self.mu = mu
        self.use_dynamic_bsz = use_dynamic_bsz
        self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer
        self.gradient_accumulation = (
            ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu
        )
        self.read_batch_size_usual = read_batch_size_usual // ngpus_trainer
        self.read_batch_size_expert = read_batch_size_expert // ngpus_trainer
        self.grpo_loss_fn = PPOPolicyLossFn(
            clip_range=clip_range,
            clip_range_low=clip_range_low,
            clip_range_high=clip_range_high,
        )
        self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft)
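    # Worked example (hypothetical settings, not from the original source):
    # with ppo_mini_batch_size=4, repeat_times=8, ngpus_trainer=2 and
    # ppo_micro_batch_size_per_gpu=4, the constructor above yields
    #   experience_per_gpu    = 4 * 8 // 2 = 16
    #   gradient_accumulation = 4 * 8 // 4 = 8
    # and a global read_batch_size_usual=16 becomes 16 // 2 = 8 per GPU, so
    # __call__ (without dynamic batch sizing) would use
    # per_micro_batch_weight_usual = 8 / 8 = 1.0.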
    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        is_expert_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        assert (
            len(is_expert_mask) == logprob.shape[0]
        ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
        n_usual_exp = torch.sum(~is_expert_mask).item()
        n_expert_exp = torch.sum(is_expert_mask).item()

        if self.use_dynamic_bsz:
            per_micro_batch_weight_usual = self.experience_per_gpu / (
                logprob.shape[0] * self.read_batch_size_usual
            )
            per_micro_batch_weight_expert = self.experience_per_gpu / (
                logprob.shape[0] * self.read_batch_size_expert
            )
        else:
            per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual  # type: ignore
            per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert  # type: ignore

        # GRPO loss (non-expert experiences)
        if n_usual_exp > 0:
            grpo_loss, grpo_metrics = self.grpo_loss_fn(
                logprob[~is_expert_mask],
                old_logprob[~is_expert_mask],
                action_mask[~is_expert_mask],
                advantages[~is_expert_mask],
                **kwargs,
            )
            grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
            grpo_metrics = {
                k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items()
            }
        else:
            grpo_loss = torch.tensor(0.0, device=logprob.device)
            grpo_metrics = {}

        # SFT loss (expert experiences)
        if n_expert_exp > 0:
            sft_loss, sft_metrics = self.sft_loss_fn(
                logprob[is_expert_mask],
                action_mask[is_expert_mask],
            )
            sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
            sft_metrics = {
                k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items()
            }
        else:
            sft_loss = torch.tensor(0.0, device=logprob.device)
            sft_metrics = {}

        # Mix the two losses with the weighting parameter mu
        loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss

        metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()}
        metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()})
        metrics.update({"loss": loss.item()})

        return loss, metrics
    @classmethod
    def default_args(cls) -> Dict:
        return {
            "mu": 0.1,
            "clip_range": 0.2,
        }
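
# ---------------------------------------------------------------------------
# Illustrative usage sketch -- not part of the original module. It assumes the
# tensor shapes implied by `__call__` above: token-level tensors of shape
# (batch, seq_len) and a per-sample boolean `is_expert_mask` of shape (batch,).
# All constructor settings below are hypothetical examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, seq_len = 4, 16
    loss_fn = MIXPolicyLossFn(
        mu=0.1,
        clip_range=0.2,
        repeat_times=1,
        ppo_mini_batch_size=4,
        ppo_micro_batch_size_per_gpu=4,
        ngpus_trainer=1,
        read_batch_size_usual=2,
        read_batch_size_expert=2,
    )
    logprob = torch.randn(batch_size, seq_len)
    old_logprob = torch.randn(batch_size, seq_len)
    action_mask = torch.ones(batch_size, seq_len)
    advantages = torch.randn(batch_size, seq_len)
    # First half of the batch is usual rollout data, second half is expert data.
    is_expert_mask = torch.tensor([False, False, True, True])
    loss, metrics = loss_fn(logprob, old_logprob, action_mask, advantages, is_expert_mask)
    print(loss, metrics)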