"""OPMD policy loss function."""
from typing import Dict, Tuple
import torch
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean
@POLICY_LOSS_FN.register_module("opmd")
class OPMDPolicyLossFn(PolicyLossFn):
    """OPMD policy loss, registered in ``POLICY_LOSS_FN`` under the key ``"opmd"``.

    The loss is the masked mean of the elementwise surrogate
    ``-advantages * logprob``, divided by ``1 + tau``.

    Args:
        backend: Forwarded unchanged to ``PolicyLossFn.__init__``.
        tau: Regularization coefficient; the reduced loss is divided by
            ``1 + tau``. NOTE(review): ``tau == -1`` would divide by zero and
            ``tau < -1`` flips the sign of the loss -- presumably callers only
            pass non-negative values; confirm.
    """

    def __init__(self, backend: str = "verl", tau: float = 1.0) -> None:
        super().__init__(backend=backend)
        self.tau = tau

    # type: ignore below -- this signature presumably narrows the base-class
    # ``PolicyLossFn.__call__``; confirm against the base definition.
    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute the OPMD loss.

        Args:
            logprob: Log-probabilities of the taken actions; multiplied
                elementwise with ``advantages`` (presumably per-token with the
                same shape -- TODO confirm).
            action_mask: Mask passed straight to ``masked_mean`` to select
                which positions contribute to the mean.
            advantages: Advantage estimates, multiplied elementwise with
                ``logprob``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``(loss, metrics)``: ``loss`` is the reduced loss tensor (not
            detached, so gradients can flow through it); ``metrics`` is
            ``{"opmd_loss": <float>}`` holding the same value as a Python
            number obtained via ``.detach().item()``.
        """
        pg_losses = -advantages * logprob
        opmd_loss = masked_mean(pg_losses, action_mask)
        opmd_loss = opmd_loss / (1.0 + self.tau)  # for regularization (w.r.t. current pi_theta)
        return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}

    @classmethod
    def default_args(cls) -> Dict:
        """Return default values for the configurable ``__init__`` arguments.

        Mirrors the ``tau=1.0`` default of ``__init__``; ``backend`` is not
        included.
        """
        return {"tau": 1.0}