trinity.algorithm.policy_loss_fn.dpo_loss module
Direct Preference Optimization (DPO) loss function.
- class trinity.algorithm.policy_loss_fn.dpo_loss.DPOLossFn(backend: str = 'verl', beta: float = 0.1, label_smoothing: float = 0.0)
Bases: PolicyLossFn
- __init__(backend: str = 'verl', beta: float = 0.1, label_smoothing: float = 0.0) -> None
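Initialize the DPO loss function.
- Parameters:
backend – The training framework/backend to use (e.g., “verl”).
beta – Temperature coefficient that scales the log-ratio margin between the chosen and rejected responses (default 0.1).
label_smoothing – Weight given to the possibility that a preference label is flipped, as in conservative DPO; 0.0 (the default) gives the standard DPO loss.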
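The beta and label_smoothing arguments in the signature above parameterize the DPO objective. Below is a minimal pure-Python sketch of that objective for a single preference pair. It assumes the commonly used conservative-DPO form of label smoothing and summed per-sequence log-probabilities as inputs. The function names `log_sigmoid` and `dpo_loss` are hypothetical; this is not the trinity implementation, which operates on batched tensors.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
    label_smoothing: float = 0.0,
) -> float:
    """Hypothetical sketch of the DPO loss for one (chosen, rejected) pair.

    Inputs are assumed to be summed log-probabilities of each full response
    under the trained policy and the frozen reference model.
    """
    # Log-ratio of policy vs. reference for each response
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    # Margin by which the policy prefers the chosen response more than the reference does
    logits = chosen_logratio - rejected_logratio
    # Assumed conservative-DPO smoothing: mix the loss for the given label
    # with the loss for the flipped label
    return (
        -log_sigmoid(beta * logits) * (1.0 - label_smoothing)
        - log_sigmoid(-beta * logits) * label_smoothing
    )


# Identical policy and reference: the margin is zero, so the loss is log 2
print(dpo_loss(-1.0, -1.0, -1.0, -1.0))  # ≈ 0.6931
```

A positive margin (the policy favors the chosen response more than the reference does) pushes the loss below log 2, and beta controls how sharply it falls. With label_smoothing = 0.5 the loss becomes symmetric in the margin, so the preference signal cancels out.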
- classmethod default_args() -> Dict
Get default initialization arguments for this loss function.
- Returns:
The default init arguments for the policy loss function.
- Return type:
Dict
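A typical use of a `default_args()` classmethod is to start from the defaults, override selected values, and pass the result to the constructor. The sketch below uses a hypothetical stand-in class, `DPOLossFnSketch`, that mirrors the documented signature. It assumes that `default_args()` returns the `beta` and `label_smoothing` defaults from `__init__`. The exact dictionary returned by the real class is not stated here.

```python
from typing import Any, Dict


class DPOLossFnSketch:
    """Hypothetical stand-in mirroring the documented DPOLossFn signature."""

    def __init__(self, backend: str = "verl", beta: float = 0.1, label_smoothing: float = 0.0) -> None:
        self.backend = backend
        self.beta = beta
        self.label_smoothing = label_smoothing

    @classmethod
    def default_args(cls) -> Dict[str, Any]:
        # Assumption: the defaults mirror the keyword defaults of __init__
        return {"beta": 0.1, "label_smoothing": 0.0}


# Start from the defaults and override only beta
args = {**DPOLossFnSketch.default_args(), "beta": 0.05}
loss_fn = DPOLossFnSketch(**args)
print(loss_fn.beta, loss_fn.label_smoothing)  # 0.05 0.0
```

Keeping the defaults in one classmethod lets configuration code fill in unspecified options without duplicating the constructor's default values.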
- property select_keys
Returns the parameter keys this loss function reads, mapped to the naming convention of the training backend selected by `backend`.
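The sketch below shows one way a `select_keys`-style property could work. It reads the named parameters of a loss function's signature and renames each one through a per-backend table. Every name in it is hypothetical, including `KEY_MAPS`, `example_loss`, and the backend-specific key strings. The real parameter names and backend mappings used by trinity and verl may differ.

```python
import inspect
from typing import Callable, Dict, List

# Hypothetical renaming table: generic parameter name -> backend-specific key.
# These key names are illustrative only, not taken from trinity or verl.
KEY_MAPS: Dict[str, Dict[str, str]] = {
    "verl": {"logprob": "log_prob", "ref_logprob": "ref_log_prob"},
}


def example_loss(logprob, ref_logprob, response_mask, **kwargs):
    """Hypothetical loss signature used only to illustrate key selection."""
    raise NotImplementedError


def select_keys(fn: Callable, backend: str) -> List[str]:
    key_map = KEY_MAPS.get(backend, {})
    keys = []
    for name, param in inspect.signature(fn).parameters.items():
        # *args / **kwargs are not data keys, so skip them
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue
        # Rename if the backend uses a different name, otherwise keep it
        keys.append(key_map.get(name, name))
    return keys


print(select_keys(example_loss, "verl"))
# ['log_prob', 'ref_log_prob', 'response_mask']
```

Deriving the keys from the signature keeps the list of required inputs in one place, the loss function's own parameters, while the backend table absorbs naming differences between frameworks.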