Source code for trinity.algorithm.policy_loss_fn.sppo_loss_fn
"""sPPO-token policy loss function.
Relevant paper: https://arxiv.org/abs/2108.05828.
"""
from typing import Dict, Tuple
import torch
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean
@POLICY_LOSS_FN.register_module("sppo")
class sPPOPolicyLossFn(PolicyLossFn):
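    """Token-level sPPO policy loss.

    Tokens whose importance ratio exp(logprob - old_logprob) falls outside
    [1/(1+epsilon), 1+epsilon] are masked out of the policy-gradient term.
    """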
    def __init__(
        self,
        backend: str = "verl",
        epsilon: float = 0.3,
    ) -> None:
        super().__init__(backend=backend)
        self.epsilon = epsilon
    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,  # [batch_size, seq_len]
        old_logprob: torch.Tensor,  # [batch_size, seq_len]
        action_mask: torch.Tensor,  # [batch_size, seq_len]
        advantages: torch.Tensor,  # [batch_size, seq_len]
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
"""Calculate sPPO loss.
The formula is as follows:
advantages*log(clip(ratio, 1/(1+epsilon), 1+epsilon))
ratio = exp(logprob - old_logprob)
"""
        # Token-wise importance ratio; detached because it is only used to
        # build the in-range mask, not to carry gradients.
        ratio = torch.exp(logprob - old_logprob).detach()
        # Keep tokens whose ratio lies in [1/(1+epsilon), 1+epsilon];
        # e.g. with epsilon=0.3 that is roughly [0.769, 1.3].
        is_in_range = (ratio >= (1 / (1 + self.epsilon))) & (ratio <= (1 + self.epsilon))
        is_clipped_mask = ~is_in_range
        pg_losses = -advantages * (logprob - old_logprob) * is_in_range.float()
        pg_loss = masked_mean(pg_losses, action_mask)
        pg_clipfrac = masked_mean(is_clipped_mask.float(), action_mask)
        metrics = {
            "pg_clipfrac": pg_clipfrac.item(),
            "pg_loss": pg_loss.detach().item(),
        }
        return pg_loss, metrics
    @classmethod
    def default_args(cls) -> Dict:
        return {
            "epsilon": 0.3,
        }
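A minimal usage sketch (not part of the module above), assuming sPPOPolicyLossFn can be constructed with its default "verl" backend and called directly on [batch_size, seq_len] tensors:

if __name__ == "__main__":
    # Illustrative only: dummy tensors with the shapes annotated in __call__.
    batch_size, seq_len = 2, 8
    loss_fn = sPPOPolicyLossFn(epsilon=0.3)  # assumes the default backend needs no extra setup
    logprob = torch.randn(batch_size, seq_len)
    old_logprob = torch.randn(batch_size, seq_len)
    action_mask = torch.ones(batch_size, seq_len)
    advantages = torch.randn(batch_size, seq_len)
    loss, metrics = loss_fn(
        logprob=logprob,
        old_logprob=old_logprob,
        action_mask=action_mask,
        advantages=advantages,
    )
    print(loss, metrics)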