trinity.algorithm.policy_loss_fn.policy_loss_fn module

class trinity.algorithm.policy_loss_fn.policy_loss_fn.PolicyLossFnMeta(name, bases, dct)

Bases: ABCMeta

Metaclass for policy loss functions that handles parameter name mapping and filtering.

ignore_keys = {'kwargs', 'logprob', 'self'}
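
The filtering role of ignore_keys can be illustrated with a minimal sketch. It assumes the metaclass inspects the signature of a loss class's `__call__` method and keeps only the parameter names outside ignore_keys; this page does not confirm that mechanism. `SketchLossMeta`, `DemoLoss`, `_param_names`, and the argument names in `__call__` are hypothetical and exist only for illustration.

```python
import inspect
from abc import ABCMeta


class SketchLossMeta(ABCMeta):
    """Hypothetical stand-in for PolicyLossFnMeta; not the library's code."""

    # Same set as the documented ignore_keys attribute.
    ignore_keys = {"kwargs", "logprob", "self"}

    def __new__(mcs, name, bases, dct):
        cls = super().__new__(mcs, name, bases, dct)
        call = dct.get("__call__")
        if call is not None:
            params = inspect.signature(call).parameters
            # Keep only the parameter names that are not in ignore_keys,
            # preserving the order in which they were declared.
            cls._param_names = [p for p in params if p not in mcs.ignore_keys]
        return cls


class DemoLoss(metaclass=SketchLossMeta):
    # Illustrative signature; the argument names are assumptions.
    def __call__(self, logprob, old_logprob, action_mask, advantages, **kwargs):
        return None


filtered = DemoLoss._param_names
print(filtered)  # ['old_logprob', 'action_mask', 'advantages']
```

In this sketch the filtering runs once, at class-creation time, so every subclass carries its own list of required inputs without extra bookkeeping.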

class trinity.algorithm.policy_loss_fn.policy_loss_fn.PolicyLossFn(backend: str = 'verl')

Bases: ABC

Abstract base class for policy loss functions.

This class provides the interface for implementing different policy gradient loss functions while handling parameter name mapping between different training frameworks.

__init__(backend: str = 'verl')

Initialize the policy loss function.

Parameters:

backend (str) – Name of the training framework (backend) whose parameter naming convention is used, e.g. “verl”. Defaults to “verl”.

abstract classmethod default_args() → Dict

Get default initialization arguments for this loss function.

Returns:

The default init arguments for the policy loss function.

Return type:

Dict
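
A minimal sketch of the abstract-classmethod pattern follows. It assumes a concrete loss exposes its hyperparameter defaults through default_args() so that a caller can merge in overrides before constructing the instance. `SketchLossBase` is a hypothetical stand-in for the documented base class, and `ClippedLoss`, `clip_low`, and `clip_high` are illustrative names, not part of the library.

```python
from abc import ABC, abstractmethod
from typing import Dict


class SketchLossBase(ABC):
    """Hypothetical stand-in mirroring the documented interface."""

    def __init__(self, backend: str = "verl") -> None:
        self.backend = backend

    @classmethod
    @abstractmethod
    def default_args(cls) -> Dict:
        """Return default keyword arguments for the constructor."""


class ClippedLoss(SketchLossBase):
    # clip_low / clip_high are illustrative hyperparameters (assumptions).
    def __init__(
        self, backend: str = "verl", clip_low: float = 0.2, clip_high: float = 0.2
    ) -> None:
        super().__init__(backend=backend)
        self.clip_low = clip_low
        self.clip_high = clip_high

    @classmethod
    def default_args(cls) -> Dict:
        return {"clip_low": 0.2, "clip_high": 0.2}


# Start from the defaults, override one value, then build the instance.
args = {**ClippedLoss.default_args(), "clip_high": 0.28}
loss_fn = ClippedLoss(backend="verl", **args)
```

Because default_args() is a classmethod, the defaults can be read without instantiating the loss, which is convenient when assembling a configuration ahead of time.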

property select_keys

Return the loss function’s parameter keys, mapped to the naming convention of the selected training framework.
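
The mapping step can be sketched as follows, assuming a per-backend lookup table that translates internal parameter names into the names the backend uses. The table contents (`old_log_probs`, `response_mask`), the `NAME_MAP` constant, and the `SketchSelectKeys` class are hypothetical; the actual mapping is defined by the library and is not shown on this page.

```python
from typing import Dict, List

# Hypothetical per-backend name table (assumption, for illustration only).
NAME_MAP: Dict[str, Dict[str, str]] = {
    "verl": {
        "old_logprob": "old_log_probs",
        "action_mask": "response_mask",
    },
}


class SketchSelectKeys:
    """Hypothetical holder of parameter names plus a backend choice."""

    def __init__(self, param_names: List[str], backend: str = "verl") -> None:
        self.param_names = param_names
        self.backend = backend

    @property
    def select_keys(self) -> List[str]:
        mapping = NAME_MAP[self.backend]
        # Names without an entry in the table pass through unchanged.
        return [mapping.get(name, name) for name in self.param_names]


keys = SketchSelectKeys(["old_logprob", "action_mask", "advantages"]).select_keys
print(keys)  # ['old_log_probs', 'response_mask', 'advantages']
```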