trinity.algorithm.policy_loss_fn.policy_loss_fn module
- class trinity.algorithm.policy_loss_fn.policy_loss_fn.PolicyLossFnMeta(name, bases, dct)[source]
Bases: ABCMeta
Metaclass for policy loss functions that handles parameter name mapping and filtering.
- ignore_keys = {'kwargs', 'logprob', 'self'}
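The filtering described above can be sketched as follows. This is a minimal illustration, not the library's implementation: it assumes the metaclass inspects the signature of `__call__` and stores the remaining names on the class. `SketchLossMeta`, `SketchLoss`, the `_select_keys` attribute, and the argument names of `__call__` are all hypothetical; only the contents of `ignore_keys` come from the documentation.

```python
import inspect
from abc import ABCMeta


class SketchLossMeta(ABCMeta):
    """Hypothetical stand-in for PolicyLossFnMeta (assumed behavior)."""

    # Same values as the documented ignore_keys attribute.
    ignore_keys = {"kwargs", "logprob", "self"}

    def __new__(mcls, name, bases, dct):
        call = dct.get("__call__")
        if call is not None:
            # Assumption: parameter names are read from __call__ and
            # every name listed in ignore_keys is dropped.
            params = inspect.signature(call).parameters
            dct["_select_keys"] = [p for p in params if p not in mcls.ignore_keys]
        return super().__new__(mcls, name, bases, dct)


class SketchLoss(metaclass=SketchLossMeta):
    # Hypothetical argument names, chosen only for illustration.
    def __call__(self, logprob, old_logprob, action_mask, advantages, **kwargs):
        return None


collected_keys = SketchLoss._select_keys
```

Under these assumptions, `collected_keys` holds only the names a loss function needs beyond `logprob`, in declaration order.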
- class trinity.algorithm.policy_loss_fn.policy_loss_fn.PolicyLossFn(backend: str = 'verl')[source]
Bases: ABC
Abstract base class for policy loss functions.
This class defines the interface for policy gradient loss functions and handles parameter name mapping between training frameworks.
- __init__(backend: str = 'verl')[source]
Initialize the policy loss function.
- Parameters:
backend – The training framework/backend to use (e.g., “verl”)
- abstract classmethod default_args() → Dict [source]
Get default initialization arguments for this loss function.
- Returns:
The default init arguments for the policy loss function.
- Return type:
Dict
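A subclass would typically override `default_args()` to return the keyword arguments its constructor accepts. The sketch below uses a self-contained stand-in base class so it runs without the library installed. `SketchPolicyLossFn`, `ClippedLoss`, and the `clip_range` parameter are hypothetical names, and building the instance as `cls(**cls.default_args())` is an assumed usage pattern, not something the documentation confirms.

```python
from abc import ABC, abstractmethod
from typing import Dict


class SketchPolicyLossFn(ABC):
    """Hypothetical stand-in mirroring the documented interface."""

    def __init__(self, backend: str = "verl"):
        self.backend = backend

    @classmethod
    @abstractmethod
    def default_args(cls) -> Dict:
        """Return default init arguments for this loss function."""


class ClippedLoss(SketchPolicyLossFn):
    # clip_range is an illustrative hyperparameter, not a documented one.
    def __init__(self, backend: str = "verl", clip_range: float = 0.2):
        super().__init__(backend=backend)
        self.clip_range = clip_range

    @classmethod
    def default_args(cls) -> Dict:
        return {"clip_range": 0.2}


# Assumed usage: build the loss function from its own defaults.
loss_fn = ClippedLoss(backend="verl", **ClippedLoss.default_args())
```

Because `default_args()` is abstract, the base class itself cannot be instantiated; each concrete loss function has to declare its own defaults.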
- property select_keys
Returns parameter keys mapped to the specific training framework’s naming convention.
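One way such a property could work is a per-backend lookup table that translates generic parameter names into the names a given framework uses. The sketch below is an assumption about the mechanism, not the library's code: `HYPOTHETICAL_NAME_MAP`, its entries, `SketchLoss`, and `_select_keys` are illustrative names only.

```python
# Hypothetical mapping from generic names to backend-specific names.
HYPOTHETICAL_NAME_MAP = {
    "verl": {
        "old_logprob": "old_log_probs",
        "action_mask": "response_mask",
    },
}


class SketchLoss:
    # Generic parameter names this loss function is assumed to need.
    _select_keys = ["old_logprob", "action_mask", "advantages"]

    def __init__(self, backend: str = "verl"):
        self.backend = backend

    @property
    def select_keys(self):
        # Names without an entry in the table pass through unchanged.
        mapping = HYPOTHETICAL_NAME_MAP.get(self.backend, {})
        return [mapping.get(key, key) for key in self._select_keys]


mapped_keys = SketchLoss(backend="verl").select_keys
```

With this design, supporting another backend only requires adding a table entry; the loss function's own signature stays the same.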