Source code for trinity.algorithm.policy_loss_fn.topr_policy_loss

"""TOPR policy loss function.
Refer to https://arxiv.org/pdf/2503.14286v1
"""
from typing import Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean


@POLICY_LOSS_FN.register_module("topr")
class TOPRPolicyLossFn(PolicyLossFn):
    def __init__(
        self,
        backend: str = "verl",
        advantage_threshold: float = 0.0,
    ) -> None:
        super().__init__(backend=backend)
        self.advantage_threshold = advantage_threshold

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,  # In TOPR, this is actually the rewards R(x, y)
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute the TOPR policy loss.

        In TOPR:
            - α = [π(y|x) / μ(y|x)]_0^1 if R(x, y) <= threshold, else 1
            - loss = -sg(α) * R(x, y) * log π(y|x)
        """
        # In the original TOPR paper, the advantages are simply the rewards R(x, y);
        # advantages can also be passed in place of rewards (baseline trick).
        rewards = advantages

        # Compute the ratio π(y|x) / μ(y|x) in log space for numerical stability
        log_ratio = logprob - old_logprob
        ratio = torch.exp(log_ratio)
        ratio_clamped = torch.clamp(ratio, min=0.0, max=1.0)

        # Apply TOPR's conditional weighting:
        # α = clamp(ratio, 0, 1) if R(x, y) <= threshold, else 1
        alpha = torch.where(
            rewards <= self.advantage_threshold, ratio_clamped, torch.ones_like(ratio)
        )

        # TOPR loss: l = -α * R(x, y) * log π(y|x).
        # We want to maximize α * R(x, y) * log π(y|x), so we minimize the negative;
        # α is detached because it is used under a stop-gradient.
        topr_loss = -alpha.detach() * rewards * logprob

        # Apply the action mask and average over valid tokens
        loss = masked_mean(topr_loss, action_mask)

        # Average alpha value for monitoring
        avg_alpha = masked_mean(alpha, action_mask)

        metrics = {
            "topr_loss": loss.detach().item(),
            "avg_alpha": avg_alpha.detach().item(),
            "avg_ratio": masked_mean(ratio, action_mask).detach().item(),
        }
        return loss, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {
            "advantage_threshold": 0.0,
        }
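

# A minimal usage sketch (illustrative only, not part of the original module):
# the tensor shapes and the direct instantiation below are assumptions; in
# practice the loss function is typically resolved through the POLICY_LOSS_FN
# registry rather than constructed by hand.
if __name__ == "__main__":
    loss_fn = TOPRPolicyLossFn(advantage_threshold=0.0)
    batch_size, seq_len = 2, 8
    logprob = torch.randn(batch_size, seq_len)      # log π(y|x) per token
    old_logprob = torch.randn(batch_size, seq_len)  # log μ(y|x) per token
    action_mask = torch.ones(batch_size, seq_len)   # 1 for response tokens, 0 for padding
    advantages = torch.randn(batch_size, seq_len)   # rewards R(x, y) broadcast per token
    loss, metrics = loss_fn(
        logprob=logprob,
        old_logprob=old_logprob,
        action_mask=action_mask,
        advantages=advantages,
    )
    print(loss.item(), metrics)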