Source code for trinity.algorithm.policy_loss_fn.rec_policy_loss

"""REC-token policy loss function.
"""

from typing import Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean


@POLICY_LOSS_FN.register_module("rec")
class RECPolicyLossFn(PolicyLossFn):
    def __init__(
        self,
        backend: str = "verl",
        epsilon_low: float = 0.2,
        epsilon_high: float = 0.2,
        epsilon_low_prime: float = 0.4,
        epsilon_high_prime: float = 0.4,
        clip_mode: str = "none",
        weight: str = "none",
        regularizer: str = "none",
        regularizer_coef: float = 0.0,
        temp: float = 1.0,
    ) -> None:
        super().__init__(backend=backend)
        self.epsilon_low = epsilon_low
        self.epsilon_high = epsilon_high
        assert 0.0 < self.epsilon_low <= 1.0, f"Invalid epsilon_low: {self.epsilon_low}"
        assert 0.0 < self.epsilon_high, f"Invalid epsilon_high: {self.epsilon_high}"
        self.epsilon_low_prime = epsilon_low_prime
        self.epsilon_high_prime = epsilon_high_prime
        assert (
            0.0 < self.epsilon_low_prime <= 1.0
        ), f"Invalid epsilon_low_prime: {self.epsilon_low_prime}"
        assert (
            0.0 < self.epsilon_high_prime
        ), f"Invalid epsilon_high_prime: {self.epsilon_high_prime}"
        self.clip_mode = clip_mode
        assert self.clip_mode in [
            "none",
            "one-side",
            "two-side",
            "ring",
        ], f"Invalid clip_mode: {self.clip_mode}"
        self.weight = weight
        assert self.weight in [
            "none",
            "importance_sampling",
            "advantage",
        ], f"Invalid weight: {self.weight}"
        self.regularizer = regularizer
        assert self.regularizer in [
            "none",
            "k2",
            "forward-kl",
        ], f"Invalid regularizer: {self.regularizer}"
        self.regularizer_coef = regularizer_coef
        assert self.regularizer_coef >= 0.0, f"Invalid regularizer_coef: {self.regularizer_coef}"
        self.temp = temp
        assert self.temp > 0.0, f"Invalid temp: {self.temp}"
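
    # Clipping regions implemented in __call__ below, with
    # r = exp(logprob - old_logprob) and A = advantages:
    #   "none":     keep every token.
    #   "two-side": keep 1 - epsilon_low <= r <= 1 + epsilon_high.
    #   "one-side": keep r <= 1 + epsilon_high where A >= 0, and
    #               r >= 1 - epsilon_low where A <= 0.
    #   "ring":     keep the two-side band, plus r <= 1 - epsilon_low_prime
    #               where A >= 0 and r >= 1 + epsilon_high_prime where A <= 0.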

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,  # [batch_size, seq_len]
        old_logprob: torch.Tensor,  # [batch_size, seq_len]
        action_mask: torch.Tensor,  # [batch_size, seq_len]
        advantages: torch.Tensor,  # [batch_size, seq_len]
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """Calculate the REC loss."""
        # Token-wise importance ratio, detached from the computation graph.
        ratio = torch.exp(logprob - old_logprob).detach()

        # Clipping: boolean mask of tokens whose ratio stays in range.
        if self.clip_mode == "two-side":
            is_in_range = (ratio >= (1 - self.epsilon_low)) * (ratio <= (1 + self.epsilon_high))
        elif self.clip_mode == "one-side":
            is_in_range = (ratio <= (1 + self.epsilon_high)) * (advantages >= 0) + (
                advantages <= 0
            ) * (ratio >= (1 - self.epsilon_low))
        elif self.clip_mode == "ring":
            is_in_range = (
                (ratio >= (1 - self.epsilon_low)) * (ratio <= (1 + self.epsilon_high))
                + (advantages >= 0) * (ratio <= 1 - self.epsilon_low_prime)
                + (advantages <= 0) * (ratio >= 1 + self.epsilon_high_prime)
            )
        else:  # "none"
            is_in_range = torch.ones_like(ratio).bool()
        is_clipped_mask = ~is_in_range

        if self.weight == "importance_sampling":
            advantages = advantages * ratio  # importance sampling
        elif self.weight == "advantage":
            weight = torch.exp(advantages / self.temp)
            advantages = advantages * weight  # advantage weighting (unnormalized version)

        pg_losses = -advantages * logprob * is_in_range.float()

        if self.regularizer == "forward-kl":
            regularizer_losses = self.regularizer_coef * logprob
            pg_losses = pg_losses - regularizer_losses
        elif self.regularizer == "k2":
            # Note that the 1/2 in the Kimi formulation is absorbed into \tau here.
            regularizer_losses = self.regularizer_coef * (logprob - old_logprob).square()
            pg_losses = pg_losses + regularizer_losses

        pg_loss = masked_mean(pg_losses, action_mask)
        pg_clipfrac = masked_mean(is_clipped_mask.float(), action_mask)
        metrics = {
            "pg_clipfrac": pg_clipfrac.item(),
            "pg_loss": pg_loss.detach().item(),
        }
        return pg_loss, metrics
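
    # Because `ratio` is detached, gradients flow only through `logprob`: on
    # unclipped tokens the update direction is -A * grad(logprob), i.e. a
    # (possibly reweighted) REINFORCE-style step, unlike PPO's clipped
    # surrogate where the ratio itself carries the gradient.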

    @classmethod
    def default_args(cls) -> Dict:
        return {
            "epsilon_low": 0.2,
            "epsilon_high": 0.2,
            "epsilon_low_prime": 0.6,
            "epsilon_high_prime": 2,
            "clip_mode": "none",
            "weight": "none",
            "regularizer": "none",
            "regularizer_coef": 0.0,
            "temp": 1.0,
        }
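
A minimal usage sketch (not part of the module source), assuming the loss
function can be constructed directly rather than through the POLICY_LOSS_FN
registry, and that masked_mean averages per-token values over positions where
action_mask is 1; tensor shapes follow the annotations in __call__:

    import torch

    from trinity.algorithm.policy_loss_fn.rec_policy_loss import RECPolicyLossFn

    loss_fn = RECPolicyLossFn(clip_mode="two-side", weight="importance_sampling")

    batch_size, seq_len = 2, 8
    old_logprob = torch.randn(batch_size, seq_len).clamp(max=0.0)
    logprob = (old_logprob + 0.1 * torch.randn(batch_size, seq_len)).requires_grad_(True)
    action_mask = torch.ones(batch_size, seq_len)
    advantages = torch.randn(batch_size, seq_len)

    loss, metrics = loss_fn(logprob, old_logprob, action_mask, advantages)
    loss.backward()  # gradients flow through `logprob` only
    print(metrics)   # {"pg_clipfrac": ..., "pg_loss": ...}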