"""Policy loss function base classes (trinity.algorithm.policy_loss_fn.policy_loss_fn)."""

import inspect
from abc import ABC, ABCMeta, abstractmethod
from typing import Dict, Tuple

import torch

from trinity.algorithm.key_mapper import ALL_MAPPERS
from trinity.utils.registry import Registry

# Registry holding all policy loss function implementations under the
# "policy_loss_fn" namespace (see trinity.utils.registry.Registry).
POLICY_LOSS_FN = Registry("policy_loss_fn")


class PolicyLossFnMeta(ABCMeta):
    """Metaclass for policy loss functions.

    It inspects each subclass's ``__call__`` signature and automatically
    wires up parameter-name mapping and filtering between Trinity's naming
    convention and the training framework's convention.
    """

    # Parameters of ``__call__`` that are never part of ``select_keys``.
    ignore_keys = {"self", "kwargs", "logprob"}

    def __new__(cls, name, bases, dct):
        """Create the class, deriving parameter handling from ``__call__``.

        For example, given::

            class PPOPolicyLossFn(PolicyLossFn):
                def __call__(
                    self,
                    logprob: torch.Tensor,
                    old_logprob: torch.Tensor,
                    action_mask: torch.Tensor,
                    advantages: torch.Tensor,
                    **kwargs,
                ) -> Tuple[torch.Tensor, Dict]:
                    ...

        this metaclass:

        1. stores all non-ignored ``__call__`` parameter names in
           ``_select_keys``;
        2. adds a ``select_keys`` property that translates those names into
           the training framework's convention via ``self.mapper``;
        3. wraps ``__call__`` so incoming keyword arguments are renamed to
           Trinity names and unused keys are dropped before the call.
        """
        call_sig = inspect.signature(dct["__call__"])
        selected = [
            param
            for param in call_sig.parameters
            if param not in PolicyLossFnMeta.ignore_keys
        ]
        dct["_select_keys"] = selected

        def select_keys(self):
            """Parameter keys translated to the training framework's naming."""
            return [self.mapper.from_trinity(key) for key in self._select_keys]

        def _wrap_call(original):
            # Wrap ``__call__`` so kwargs are renamed to Trinity convention
            # and anything the loss function does not consume is discarded.
            def wrapped(self, *args, **kwargs):
                filtered = {}
                for raw_key, value in kwargs.items():
                    key = self.mapper.to_trinity(raw_key)
                    # keep only "logprob" and the keys ``__call__`` declares
                    if key == "logprob" or key in self._select_keys:
                        filtered[key] = value
                return original(self, *args, **filtered)

            return wrapped

        # Attach the generated property and the wrapped ``__call__``.
        dct["select_keys"] = property(select_keys)
        dct["__call__"] = _wrap_call(dct["__call__"])
        return super().__new__(cls, name, bases, dct)
class PolicyLossFn(ABC, metaclass=PolicyLossFnMeta):
    """Abstract base class for policy loss functions.

    Defines the interface for policy-gradient loss implementations while the
    metaclass transparently handles parameter-name mapping between different
    training frameworks.
    """

    def __init__(self, backend: str = "verl"):
        """Initialize the policy loss function.

        Args:
            backend: Name of the training framework/backend (e.g. ``"verl"``);
                must be a key of ``ALL_MAPPERS``.
        """
        self.backend = backend
        # Key mapper translating parameter names between Trinity's
        # convention and the chosen backend's convention.
        self.mapper = ALL_MAPPERS[self.backend]

    @abstractmethod
    def __call__(
        self,
        logprob: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """Calculate the policy loss.

        Args:
            logprob (`torch.Tensor`): Log probability produced by the policy
                model.

        Kwargs (optional):
            old_logprob (`torch.Tensor`): Log probability produced by the
                reference model.
            action_mask (`torch.Tensor`): The action mask.
            advantages (`torch.Tensor`): The advantages.
            kwargs (`Dict`): Step-level parameters for the loss computation.

        Returns:
            `torch.Tensor`: The policy loss.
            `Dict`: Metrics for logging.
        """

    @classmethod
    @abstractmethod
    def default_args(cls) -> Dict:
        """Get default initialization arguments for this loss function.

        Returns:
            `Dict`: Default init arguments for the policy loss function.
        """