trinity.algorithm.policy_loss_fn package

Submodules

Module contents

class trinity.algorithm.policy_loss_fn.PolicyLossFn(backend: str = 'verl')

Bases: ABC

Abstract base class for policy loss functions.

This class provides the interface for implementing different policy gradient loss functions while handling parameter name mapping between different training frameworks.
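A minimal sketch of how a concrete loss might plug into this interface. The stand-in base class below mirrors only the documented constructor and the abstract `default_args` classmethod. The `__call__` signature, the `VanillaPGLoss` subclass, and the use of plain Python lists instead of tensors are hypothetical illustrations, not this package's actual API.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class PolicyLossFn(ABC):
    """Hypothetical stand-in for the documented base class."""

    def __init__(self, backend: str = "verl"):
        self.backend = backend

    @classmethod
    @abstractmethod
    def default_args(cls) -> Dict:
        """Return the default init arguments for this loss."""

    @abstractmethod
    def __call__(self, logprob: List[float], advantages: List[float]) -> float:
        """Compute a scalar loss (hypothetical signature)."""


class VanillaPGLoss(PolicyLossFn):
    """Hypothetical REINFORCE-style loss: -mean(logprob * advantage)."""

    @classmethod
    def default_args(cls) -> Dict:
        return {}

    def __call__(self, logprob: List[float], advantages: List[float]) -> float:
        # Weight each token's log-probability by its advantage, then negate the mean.
        products = [lp * adv for lp, adv in zip(logprob, advantages)]
        return -sum(products) / len(products)


loss_fn = VanillaPGLoss()
print(loss_fn([-0.5, -1.0], [1.0, 1.0]))  # 0.75
```

Because both `default_args` and `__call__` are abstract in this sketch, instantiating the stand-in `PolicyLossFn` directly raises `TypeError`. Only subclasses that implement both methods can be constructed.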

__init__(backend: str = 'verl')

Initialize the policy loss function.

Parameters:

backend -- The training framework/backend to use (e.g., "verl")

abstractmethod classmethod default_args() → Dict

Get default initialization arguments for this loss function.

Returns:

The default initialization arguments for the policy loss function.

Return type:

Dict
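One way such defaults are commonly consumed is to merge them with user overrides before constructing the loss. The sketch below assumes that pattern. The `ClippedLoss` subclass, its `clip_range` hyperparameter, and the `build_loss` helper are hypothetical and not part of the documented API.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Type


class PolicyLossFn(ABC):
    """Hypothetical stand-in for the documented base class."""

    def __init__(self, backend: str = "verl"):
        self.backend = backend

    @classmethod
    @abstractmethod
    def default_args(cls) -> Dict:
        """Return the default init arguments for this loss."""


class ClippedLoss(PolicyLossFn):
    """Hypothetical subclass with a single tunable hyperparameter."""

    def __init__(self, backend: str = "verl", clip_range: float = 0.2):
        super().__init__(backend)
        self.clip_range = clip_range

    @classmethod
    def default_args(cls) -> Dict:
        return {"clip_range": 0.2}


def build_loss(loss_cls: Type[PolicyLossFn], overrides: Dict[str, Any]) -> PolicyLossFn:
    # Start from the class defaults; user-supplied values win on conflicts.
    kwargs = {**loss_cls.default_args(), **overrides}
    return loss_cls(**kwargs)


loss_fn = build_loss(ClippedLoss, {"clip_range": 0.1})
print(loss_fn.clip_range, loss_fn.backend)  # 0.1 verl
```

Making `default_args` a classmethod lets a caller inspect the defaults before any instance exists, for example to populate a config file.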

property select_keys

Returns parameter keys mapped to the specific training framework's naming convention.
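The behavior of this property can be pictured as a per-backend rename table. This is a sketch under stated assumptions. The `KeyMappedLoss` class, the `"generic"` backend, the rename tables, and all key names below are hypothetical, because this page does not document the package's actual mapping.

```python
from typing import Dict, List


class KeyMappedLoss:
    """Hypothetical sketch of backend-aware key renaming."""

    # Hypothetical rename tables: generic key -> backend-specific key.
    _KEY_MAPS: Dict[str, Dict[str, str]] = {
        "verl": {"logprob": "log_prob", "mask": "response_mask"},
        "generic": {},
    }
    # Hypothetical generic names of the inputs this loss consumes.
    _REQUIRED: List[str] = ["logprob", "mask", "advantages"]

    def __init__(self, backend: str = "verl"):
        if backend not in self._KEY_MAPS:
            raise ValueError(f"unsupported backend: {backend!r}")
        self.backend = backend

    @property
    def select_keys(self) -> List[str]:
        mapping = self._KEY_MAPS[self.backend]
        # Keep the generic name when the backend registers no rename for it.
        return [mapping.get(key, key) for key in self._REQUIRED]


print(KeyMappedLoss("verl").select_keys)  # ['log_prob', 'response_mask', 'advantages']
```

Because the renaming lives in one table per backend, supporting another framework in this sketch only requires adding a table entry. The loss computation itself does not change.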