Source code for trinity.algorithm.policy_loss_fn.sft_loss
"""SFT loss function."""
from typing import Dict, Tuple
import torch
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_loss
[docs]
@POLICY_LOSS_FN.register_module("sft")
class SFTLossFn(PolicyLossFn):
    """SFT policy loss, registered in ``POLICY_LOSS_FN`` under the name ``"sft"``.

    The loss is the negated log-probability tensor reduced by ``masked_loss``
    with the supplied action mask and the configured ``loss_agg_mode``.
    """

    def __init__(self, backend: str = "verl", loss_agg_mode: str = "token-mean") -> None:
        """Set up the loss function.

        Args:
            backend: Forwarded unchanged to ``PolicyLossFn.__init__``.
            loss_agg_mode: Aggregation mode handed to ``masked_loss`` on every
                call; the set of accepted values is defined by that helper.
        """
        super().__init__(backend=backend)
        self.loss_agg_mode = loss_agg_mode

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        action_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute the SFT loss and its scalar metric.

        Args:
            logprob: Log-probabilities; negated before being reduced.
                NOTE(review): presumably per-token with the same shape as
                ``action_mask`` -- confirm against ``masked_loss``.
            action_mask: Mask passed as the second argument to ``masked_loss``.
            **kwargs: Accepted and ignored by this implementation.

        Returns:
            A ``(loss, metrics)`` pair: the tensor returned by ``masked_loss``
            and a dict holding its detached Python scalar under ``"sft_loss"``.
        """
        negative_logprob = -logprob
        loss = masked_loss(
            negative_logprob,
            action_mask,
            loss_agg_mode=self.loss_agg_mode,
        )
        # ``.item()`` requires a single-element tensor; detaching first keeps
        # the logged value out of the autograd graph.
        metrics = {"sft_loss": loss.detach().item()}
        return loss, metrics

    @classmethod
    def default_args(cls):
        """Return the default keyword arguments for this loss function."""
        defaults = {"loss_agg_mode": "token-mean"}
        return defaults