"""PPO policy loss function.
Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""
from typing import Dict, Optional, Tuple
import torch
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean


@POLICY_LOSS_FN.register_module("ppo")
class PPOPolicyLossFn(PolicyLossFn):
    """PPO clipped surrogate policy loss."""

    def __init__(
        self,
        backend: str = "verl",
        clip_range: Optional[float] = None,
        clip_range_low: Optional[float] = None,
        clip_range_high: Optional[float] = None,
    ) -> None:
        super().__init__(backend=backend)
        # Asymmetric clipping: fall back to the symmetric `clip_range` when the
        # low/high bounds are not given explicitly.
        if clip_range_low is None:
            self.clip_range_low = clip_range
        else:
            self.clip_range_low = clip_range_low
        if clip_range_high is None:
            self.clip_range_high = clip_range
        else:
            self.clip_range_high = clip_range_high
        assert self.clip_range_low is not None, "clip_range_low must be specified."
        assert self.clip_range_high is not None, "clip_range_high must be specified."

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        # Importance ratio r = exp(logprob - old_logprob); ppo_kl tracks the
        # approximate KL divergence between the old and current policies.
        negative_approx_kl = logprob - old_logprob
        ratio = torch.exp(negative_approx_kl)
        ppo_kl = masked_mean(-negative_approx_kl, action_mask)

        # Clipped surrogate loss: max(-A * r, -A * clip(r, 1 - eps_low, 1 + eps_high)),
        # averaged over the valid (unmasked) tokens.
        pg_losses = -advantages * ratio
        pg_losses2 = -advantages * torch.clamp(
            ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
        )
        pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), action_mask)
        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)

        metrics = {
            "pg_clipfrac": pg_clipfrac.detach().item(),
            "ppo_kl": ppo_kl.detach().item(),
            "pg_loss": pg_loss.detach().item(),
        }
        return pg_loss, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {
            "clip_range": 0.2,
        }
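

# A minimal usage sketch (not part of the original module): it assumes the class
# above is importable and that `logprob`, `old_logprob`, `advantages`, and
# `action_mask` share a common (batch, seq_len) shape. Tensor values here are
# random placeholders purely for illustration.
if __name__ == "__main__":
    batch_size, seq_len = 2, 8  # hypothetical sizes
    loss_fn = PPOPolicyLossFn(clip_range=0.2)
    logprob = torch.randn(batch_size, seq_len)
    old_logprob = torch.randn(batch_size, seq_len)
    advantages = torch.randn(batch_size, seq_len)
    action_mask = torch.ones(batch_size, seq_len)
    loss, metrics = loss_fn(
        logprob=logprob,
        old_logprob=old_logprob,
        action_mask=action_mask,
        advantages=advantages,
    )
    print(loss.item(), metrics)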