# Source code for trinity.algorithm.policy_loss_fn.policy_loss_fn
import functools
import inspect
from abc import ABC, ABCMeta, abstractmethod
from typing import Dict, Tuple

import torch

from trinity.algorithm.key_mapper import ALL_MAPPERS
from trinity.utils.registry import Registry
POLICY_LOSS_FN = Registry("policy_loss_fn")
class PolicyLossFnMeta(ABCMeta):
    """Metaclass for policy loss functions that handles parameter name mapping and filtering."""

    # Keys excluded from parameter selection: `self`/`kwargs` are not data
    # inputs, and `logprob` is always forwarded unconditionally by the wrapper.
    ignore_keys = {"self", "kwargs", "logprob"}

    def __new__(cls, name, bases, dct):
        """
        Metaclass constructor that automatically generates parameter handling logic.

        For example with `PPOPolicyLossFn` class:

        .. code-block:: python

            class PPOPolicyLossFn(PolicyLossFn):
                ...
                def __call__(
                    self,
                    logprob: torch.Tensor,
                    old_logprob: torch.Tensor,
                    action_mask: torch.Tensor,
                    advantages: torch.Tensor,
                    **kwargs,
                ) -> Tuple[torch.Tensor, Dict]:
                    ...

        This metaclass analyzes the __call__ method's parameters to:
        1. Generate _select_keys containing all non-ignored parameters
        2. Create select_keys property that maps parameters to trainer-specific names
        3. Apply decorator to automatically convert input parameter names using the mapper
        """
        # Only process classes that define (or override) __call__ themselves.
        # Classes that merely inherit __call__ reuse the already-wrapped version
        # from their base; the previous implementation raised KeyError on
        # `dct["__call__"]` for such subclasses.
        if "__call__" in dct:
            signature = inspect.signature(dct["__call__"])
            param_names = [
                key
                for key in signature.parameters.keys()
                if key not in PolicyLossFnMeta.ignore_keys
            ]
            dct["_select_keys"] = param_names

            # Property to return trainer-specific parameter names
            def select_keys(self):
                """Returns parameter keys mapped to the specific training framework's naming convention."""
                return [self.mapper.from_trinity(key) for key in self._select_keys]

            # Decorator to handle parameter name conversion before calling __call__
            def decorator(func):
                # functools.wraps preserves the wrapped function's metadata,
                # including __isabstractmethod__ set by @abstractmethod -- without
                # it, ABCMeta (which scans the namespace after this wrap) would
                # silently lose the abstractness of __call__.
                @functools.wraps(func)
                def wrapper(self, *args, **kwargs):
                    # Filter and convert parameter names from the training
                    # framework's convention to trinity's; unknown keys are dropped.
                    new_kwargs = {}
                    for key, value in kwargs.items():
                        key = self.mapper.to_trinity(key)
                        if key == "logprob" or key in self._select_keys:  # remove unused keys
                            new_kwargs[key] = value
                    return func(self, *args, **new_kwargs)

                return wrapper

            # Add the property and decorated method to the class
            dct["select_keys"] = property(select_keys)
            dct["__call__"] = decorator(dct["__call__"])
        return super().__new__(cls, name, bases, dct)
class PolicyLossFn(ABC, metaclass=PolicyLossFnMeta):
    """
    Abstract base class for policy loss functions.

    This class provides the interface for implementing different policy gradient loss functions
    while handling parameter name mapping between different training frameworks.
    """

    def __init__(self, backend: str = "verl"):
        """
        Initialize the policy loss function.

        Args:
            backend: The training framework/backend to use (e.g., "verl")

        Raises:
            KeyError: If no key mapper is registered for ``backend``.
        """
        self.backend = backend
        try:
            self.mapper = ALL_MAPPERS[self.backend]
        except KeyError:
            # Keep raising KeyError for backward compatibility, but make the
            # failure self-explanatory by listing the registered backends.
            raise KeyError(
                f"Unknown backend {backend!r}; available backends: {list(ALL_MAPPERS)}"
            ) from None

    @abstractmethod
    def __call__(
        self,
        logprob: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Calculate the policy loss.

        Args:
            logprob (`torch.Tensor`): The log probability generated by the policy model.

        Kwargs (optional):
            old_logprob (`torch.Tensor`): The log probability generated by the reference model.
            action_mask (`torch.Tensor`): The action mask.
            advantages (`torch.Tensor`): The advantages.
            kwargs (`Dict`): The step-level parameters for calculating the policy loss.

        Returns:
            `torch.Tensor`: Policy loss
            `Dict`: The metrics for logging.
        """

    @classmethod
    @abstractmethod
    def default_args(cls) -> Dict:
        """
        Get default initialization arguments for this loss function.

        Returns:
            `Dict`: The default init arguments for the policy loss function.
        """