Source code for trinity.algorithm.policy_loss_fn.opmd_policy_loss

"""OPMD policy loss function."""

from typing import Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean


@POLICY_LOSS_FN.register_module("opmd")
class OPMDPolicyLossFn(PolicyLossFn):
    """OPMD policy loss."""

    def __init__(self, backend: str = "verl", tau: float = 1.0) -> None:
        super().__init__(backend=backend)
        self.tau = tau

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        # Plain policy-gradient term, averaged over valid (unmasked) tokens.
        pg_losses = -advantages * logprob
        opmd_loss = masked_mean(pg_losses, action_mask)
        # Scale by 1 / (1 + tau) for regularization (w.r.t. current pi_theta).
        opmd_loss = opmd_loss / (1.0 + self.tau)
        return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}

    @classmethod
    def default_args(cls) -> Dict:
        return {"tau": 1.0}
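
A minimal usage sketch (not part of the module): the tensor shapes below are illustrative, masked_mean is assumed to average pg_losses over positions where action_mask is nonzero, and the base PolicyLossFn constructor is assumed to only record the backend.

# Illustrative usage (assumed shapes: batch x sequence length)
loss_fn = OPMDPolicyLossFn(tau=1.0)
logprob = torch.randn(2, 8)        # per-token log-probs under the current policy
action_mask = torch.ones(2, 8)     # 1 for response tokens, 0 for padding
advantages = torch.randn(2, 8)     # per-token advantage estimates
loss, metrics = loss_fn(logprob=logprob, action_mask=action_mask, advantages=advantages)
print(metrics["opmd_loss"])        # scalar loss value, detached for logging

With tau = 1.0 (the default from default_args), the scaling factor 1 / (1 + tau) halves the policy-gradient term; larger tau values shrink the update further, strengthening the regularization toward the current policy.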