trinity.algorithm.policy_loss_fn package
Submodules
- trinity.algorithm.policy_loss_fn.chord_policy_loss module
- trinity.algorithm.policy_loss_fn.cispo_policy_loss module
- trinity.algorithm.policy_loss_fn.dpo_loss module
- trinity.algorithm.policy_loss_fn.gspo_policy_loss module
- trinity.algorithm.policy_loss_fn.mix_policy_loss module
- trinity.algorithm.policy_loss_fn.opmd_policy_loss module
- trinity.algorithm.policy_loss_fn.policy_loss_fn module
- trinity.algorithm.policy_loss_fn.ppo_policy_loss module
- trinity.algorithm.policy_loss_fn.rec_policy_loss module
- trinity.algorithm.policy_loss_fn.sft_loss module
- trinity.algorithm.policy_loss_fn.sppo_loss_fn module
- trinity.algorithm.policy_loss_fn.topr_policy_loss module
Module contents
- class trinity.algorithm.policy_loss_fn.PolicyLossFn(backend: str = 'verl')
Bases: ABC
Abstract base class for policy loss functions.
It defines the interface that concrete policy-gradient loss functions implement, and it maps parameter names between the naming conventions of different training frameworks.
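The shape of that interface can be sketched in a few lines. This is a hypothetical illustration, not Trinity-RFT's code: `PolicyLossSketch`, `ToyReinforceLoss`, the `__call__` signature, and the `(loss, metrics)` return convention are all assumptions, and plain Python lists stand in for tensors so the example runs without a deep-learning framework. Because the base class derives from `ABC`, it cannot be instantiated until a subclass implements the abstract method.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple


class PolicyLossSketch(ABC):
    """Hypothetical stand-in for the PolicyLossFn interface (not the library's code)."""

    def __init__(self, backend: str = "verl") -> None:
        self.backend = backend

    @abstractmethod
    def __call__(
        self, logprob: List[float], advantages: List[float], action_mask: List[int]
    ) -> Tuple[float, Dict[str, float]]:
        """Return a scalar loss and a dict of metrics to log."""


class ToyReinforceLoss(PolicyLossSketch):
    """Toy policy-gradient loss: -mean(advantage * logprob) over unmasked tokens."""

    def __call__(self, logprob, advantages, action_mask):
        # Keep only tokens whose mask is non-zero.
        terms = [-a * lp for lp, a, m in zip(logprob, advantages, action_mask) if m]
        loss = sum(terms) / max(len(terms), 1)
        return loss, {"num_tokens": float(len(terms))}


loss_fn = ToyReinforceLoss()
loss, metrics = loss_fn([-1.0, -2.0, -0.5], [1.0, 0.5, 2.0], [1, 1, 0])
print(loss, metrics)  # 1.0 {'num_tokens': 2.0}
```

Calling `PolicyLossSketch()` directly raises `TypeError`, which is how `ABC` enforces that every concrete loss supplies its own computation.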
- __init__(backend: str = 'verl')
Initialize the policy loss function.
- Parameters:
backend – The training framework/backend to use (e.g., “verl”)
- abstract classmethod default_args() → Dict
Get default initialization arguments for this loss function.
- Returns:
The default init arguments for the policy loss function.
- Return type:
Dict
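A hedged sketch of how a classmethod like `default_args()` is typically consumed: a subclass returns its default hyperparameters as a dict, and the caller merges user overrides on top before constructing the instance. `LossFnBase`, `ClippedLossSketch`, its `clip_range` and `entropy_coef` parameters, and the `build()` helper are hypothetical; they illustrate the pattern rather than any specific Trinity-RFT loss.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type


class LossFnBase(ABC):
    """Hypothetical base mirroring the documented constructor and default_args()."""

    def __init__(self, backend: str = "verl") -> None:
        self.backend = backend

    @classmethod
    @abstractmethod  # abstractmethod must be the innermost decorator
    def default_args(cls) -> Dict[str, Any]:
        """Return the default init arguments for this loss function."""


class ClippedLossSketch(LossFnBase):
    """Hypothetical loss with two tunable hyperparameters (names are illustrative)."""

    def __init__(self, backend: str = "verl", clip_range: float = 0.2, entropy_coef: float = 0.0) -> None:
        super().__init__(backend)
        self.clip_range = clip_range
        self.entropy_coef = entropy_coef

    @classmethod
    def default_args(cls) -> Dict[str, Any]:
        # A fresh dict on every call, so callers can modify it safely.
        return {"clip_range": 0.2, "entropy_coef": 0.0}


def build(cls: Type[LossFnBase], overrides: Optional[Dict[str, Any]] = None, backend: str = "verl") -> LossFnBase:
    """Hypothetical helper: user overrides win over cls.default_args()."""
    args = {**cls.default_args(), **(overrides or {})}
    return cls(backend=backend, **args)


loss_fn = build(ClippedLossSketch, {"clip_range": 0.1})
print(loss_fn.clip_range, loss_fn.entropy_coef)  # 0.1 0.0
```

Because `default_args()` builds a new dict on each call, the override above leaves the class defaults untouched.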
- property select_keys
Returns the parameter keys, renamed to follow the selected training framework’s naming convention.
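One way to picture a property like `select_keys`: the loss declares the canonical inputs it needs, and the property reports them under the active backend's names so a trainer knows which fields to pull from a batch. This is a hypothetical sketch; the `BACKEND_NAMES` table, the `required_keys` attribute, the `SelectKeysSketch` class, and every key name shown are assumptions, not the library's definitions.

```python
from typing import Dict, List

# Hypothetical canonical-name -> backend-name tables (illustrative only).
BACKEND_NAMES: Dict[str, Dict[str, str]] = {
    "verl": {"old_logprob": "old_log_probs", "action_mask": "response_mask"},
}


class SelectKeysSketch:
    """Hypothetical loss function that exposes its inputs via select_keys."""

    # Canonical names of the batch fields this loss reads.
    required_keys: List[str] = ["old_logprob", "action_mask", "advantages"]

    def __init__(self, backend: str = "verl") -> None:
        self.backend = backend

    @property
    def select_keys(self) -> List[str]:
        # Rename each key for the active backend; unmapped keys pass through unchanged.
        mapping = BACKEND_NAMES.get(self.backend, {})
        return [mapping.get(key, key) for key in self.required_keys]


fn = SelectKeysSketch(backend="verl")
print(fn.select_keys)  # ['old_log_probs', 'response_mask', 'advantages']

# A trainer could then pull only the fields the loss needs from a larger batch.
batch = {"old_log_probs": [-0.3], "response_mask": [1], "advantages": [0.5], "input_ids": [42]}
inputs = {key: batch[key] for key in fn.select_keys}
print(sorted(inputs))  # ['advantages', 'old_log_probs', 'response_mask']
```

Keeping the loss written against canonical names and confining the renaming to one property means supporting another backend only requires a new entry in the mapping table.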