Source code for trinity.algorithm.policy_loss_fn.sft_loss

"""SFT loss function."""

from typing import Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean


@POLICY_LOSS_FN.register_module("sft")
class SFTLossFn(PolicyLossFn):
    """Negative log-likelihood loss for supervised fine-tuning.

    Registered in ``POLICY_LOSS_FN`` under the key ``"sft"``. The loss is the
    masked average of ``-logprob`` over the positions selected by
    ``action_mask``. How the averaging is grouped depends on
    ``use_token_level_loss``.
    """

    def __init__(self, backend: str = "verl", use_token_level_loss: bool = True) -> None:
        """Configure the SFT loss.

        Args:
            backend: Passed through unchanged to ``PolicyLossFn.__init__``.
            use_token_level_loss: When True, average over all masked tokens
                in one ``masked_mean`` call. When False, first compute a
                masked mean along ``axis=1`` (one value per row), then take
                the plain mean of those row values.
        """
        super().__init__(backend=backend)
        self.use_token_level_loss = use_token_level_loss

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        action_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute the SFT loss and a scalar metric for logging.

        Args:
            logprob: Log-probabilities of the target tokens.
                NOTE(review): presumably shaped (batch, seq_len), given the
                ``axis=1`` reduction below. TODO confirm against callers.
            action_mask: Mask selecting which positions contribute. It is
                passed straight to ``masked_mean``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A ``(loss, metrics)`` tuple:
            - ``loss`` is the (non-detached) loss tensor.
            - ``metrics`` maps ``"sft_loss"`` to the loss as a Python number.
        """
        neg_logprob = -logprob
        if not self.use_token_level_loss:
            # Sequence-level: average within each row first, then across rows.
            # This weights every row equally, regardless of its masked length.
            per_row_loss = masked_mean(neg_logprob, action_mask, axis=1)
            loss = per_row_loss.mean()
        else:
            # Token-level: one masked average over every selected token.
            loss = masked_mean(neg_logprob, action_mask)
        # Only the logged metric is detached; the returned loss keeps its graph.
        metrics = {"sft_loss": loss.detach().item()}
        return loss, metrics

    @classmethod
    def default_args(cls):
        """Return the default constructor keyword arguments for this loss."""
        defaults = dict(use_token_level_loss=True)
        return defaults