Source code for trinity.algorithm.policy_loss_fn.cispo_policy_loss
"""CISPO policy loss function.
Refer to https://arxiv.org/abs/2506.13585 for details.
"""
from typing import Dict, Tuple
import torch
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean
[docs]
@POLICY_LOSS_FN.register_module("cispo")
class CISPOPolicyLossFn(PolicyLossFn):
[docs]
def __init__(
self,
backend: str = "verl",
clip_range_low: float = 1.0,
clip_range_high: float = 0.28,
enable_mask_clip: bool = False,
mask_clip_range_low: float = 1.0,
mask_clip_range_high: float = 0.28,
) -> None:
super().__init__(backend=backend)
self.clip_range_low = clip_range_low
self.clip_range_high = clip_range_high
self.enable_mask_clip = enable_mask_clip
self.mask_clip_range_low = mask_clip_range_low
self.mask_clip_range_high = mask_clip_range_high
def __call__( # type: ignore
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
negative_approx_kl = logprob - old_logprob
ratio = torch.exp(negative_approx_kl)
ppo_kl = masked_mean(-negative_approx_kl, action_mask)
ratio_clamped = torch.clamp(
ratio, min=1.0 - self.clip_range_low, max=1.0 + self.clip_range_high
)
# mask = 0 if ratio > 1.0 + self.clip_range_high and advantages > 0
# mask = 0 if ratio < 1.0 - self.clip_range_low and advantages < 0
# else 1
mask = torch.ones_like(ratio)
if self.enable_mask_clip:
mask = torch.where(
(ratio > 1.0 + self.mask_clip_range_high) & (advantages > 0),
torch.zeros_like(ratio),
mask,
)
mask = torch.where(
(ratio < 1.0 - self.mask_clip_range_low) & (advantages < 0),
torch.zeros_like(ratio),
mask,
)
cispo_loss = -advantages * ratio_clamped.detach() * mask.detach() * logprob
loss = masked_mean(cispo_loss, action_mask)
masked_frac = masked_mean(mask, action_mask)
metrics = {
"cispo_loss": loss.detach().item(),
"ppo_kl": ppo_kl.detach().item(),
"masked_frac": masked_frac.detach().item(),
}
return loss, metrics
[docs]
@classmethod
def default_args(cls) -> Dict:
"""
In the original paper:
we did not impose a lower bound on the IS weight by setting clip_range_low to a high value, instead, we only tuned clip_range_high
"""
return {
"clip_range_low": 1.0,
"clip_range_high": 0.28,
"enable_mask_clip": False,
"mask_clip_range_low": 1.0,
"mask_clip_range_high": 0.28,
}