trinity.algorithm.policy_loss_fn.sapo_policy_loss module
SAPO policy loss function. Soft Adaptive Policy Optimization (SAPO) is a reinforcement learning algorithm that replaces hard clipping of the importance sampling ratio with a smooth, temperature-controlled soft gate.
See the SAPO paper for details: https://arxiv.org/abs/2511.20347
- class trinity.algorithm.policy_loss_fn.sapo_policy_loss.SAPOPolicyLossFn(backend: str = 'verl', tau_pos: float = 1.0, tau_neg: float = 1.05, loss_agg_mode: str = 'token-mean')
Bases: PolicyLossFn
- __init__(backend: str = 'verl', tau_pos: float = 1.0, tau_neg: float = 1.05, loss_agg_mode: str = 'token-mean') → None
Initialize the SAPO policy loss function.
- Parameters:
backend -- The training framework/backend to use (e.g., "verl")
tau_pos -- Temperature for positive advantages (τ_pos), default 1.0
tau_neg -- Temperature for negative advantages (τ_neg), default 1.05; should be >= tau_pos
loss_agg_mode -- Mode for aggregating the per-token loss, default "token-mean"
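The constructor arguments can be mirrored in a small stand-in for experimentation. This is a hypothetical sketch, not the Trinity-RFT class: `SAPOConfig` is an illustrative name, and the positivity and `tau_neg >= tau_pos` checks are assumptions drawn from the parameter notes (the real class may not enforce them).

```python
from dataclasses import dataclass


@dataclass
class SAPOConfig:
    """Hypothetical container mirroring SAPOPolicyLossFn's constructor defaults."""

    backend: str = "verl"
    tau_pos: float = 1.0
    tau_neg: float = 1.05
    loss_agg_mode: str = "token-mean"

    def __post_init__(self) -> None:
        # Assumed check: tau appears in the denominator 4 / tau, so it must be positive.
        if self.tau_pos <= 0 or self.tau_neg <= 0:
            raise ValueError("temperatures must be positive")
        # Assumed check: the docs say tau_neg should be >= tau_pos.
        if self.tau_neg < self.tau_pos:
            raise ValueError("tau_neg should be >= tau_pos")


cfg = SAPOConfig(tau_neg=1.2)
print(cfg)
```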
- soft_gate_function(ratio: Tensor, advantages: Tensor) → Tensor
Compute the soft gate function f_{i,t}(x).
- The soft gate function is defined as:
f_{i,t}(x) = σ(τ_{i,t} * (x - 1)) * 4 / τ_{i,t}
- where:
σ is the sigmoid function
τ_{i,t} is the asymmetric temperature: tau_pos when Â_i > 0 and tau_neg when Â_i < 0
x is the importance sampling ratio r_{i,t}(θ)
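A scalar sketch of the gate makes its behavior at the on-policy point (x = 1) concrete. It assumes plain Python floats instead of tensors, and the helper names `soft_gate` and `soft_gate_grad` are hypothetical. Using tau_neg when the advantage is exactly zero is also an assumption; it does not affect the SAPO objective, where the gate is multiplied by the advantage.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def soft_gate(ratio: float, advantage: float, tau_pos: float = 1.0, tau_neg: float = 1.05) -> float:
    """f(x) = sigmoid(tau * (x - 1)) * 4 / tau, with tau chosen by the advantage sign."""
    tau = tau_pos if advantage > 0 else tau_neg  # assumption: zero advantage uses tau_neg
    return sigmoid(tau * (ratio - 1.0)) * 4.0 / tau


def soft_gate_grad(ratio: float, advantage: float, tau_pos: float = 1.0, tau_neg: float = 1.05) -> float:
    """df/dx = 4 * s * (1 - s) with s = sigmoid(tau * (x - 1)); the 4 / tau factor cancels tau."""
    tau = tau_pos if advantage > 0 else tau_neg
    s = sigmoid(tau * (ratio - 1.0))
    return 4.0 * s * (1.0 - s)


print(soft_gate(1.0, 1.0))        # prints 2.0 (value 2 / tau_pos at the on-policy point)
print(soft_gate_grad(1.0, -1.0))  # prints 1.0 (gradient weight at x = 1, for any tau)
```

Because the chain-rule factor τ cancels the 1/τ scale, the gate's slope at x = 1 is exactly 1 for either temperature, so on-policy tokens receive the ordinary policy-gradient weight, while the gate's value there is 2/τ.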
- Parameters:
ratio -- Token-level importance sampling ratio r_{i,t}(θ)
advantages -- Normalized advantage Â_i (shared by all tokens in a sequence)
- Returns:
The soft gate values for each token
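In the SAPO objective, the gate value is multiplied by the advantage. The sketch below assumes the loss is the negative of f_{i,t}(r_{i,t}) · Â_i averaged over all unmasked tokens in the batch, which is one reading of the "token-mean" aggregation mode. The function name `sapo_token_mean_loss`, the nested-list inputs, and the 0/1 mask format are hypothetical; the actual module operates on tensors and may aggregate differently.

```python
import math
from typing import List


def soft_gate(ratio: float, advantage: float, tau_pos: float = 1.0, tau_neg: float = 1.05) -> float:
    # Same gate as in the docstring: sigmoid(tau * (x - 1)) * 4 / tau.
    tau = tau_pos if advantage > 0 else tau_neg
    return 4.0 / tau / (1.0 + math.exp(-tau * (ratio - 1.0)))


def sapo_token_mean_loss(
    log_probs: List[List[float]],
    old_log_probs: List[List[float]],
    advantages: List[float],
    mask: List[List[int]],
) -> float:
    """Hypothetical SAPO loss: -mean over valid tokens of f(r) * A."""
    total, count = 0.0, 0
    for lp_seq, old_seq, adv, m_seq in zip(log_probs, old_log_probs, advantages, mask):
        for lp, old, m in zip(lp_seq, old_seq, m_seq):
            if not m:
                continue  # padding tokens do not contribute
            ratio = math.exp(lp - old)  # r_{i,t} = pi_theta / pi_old
            total += -soft_gate(ratio, adv) * adv  # sequence advantage shared by every token
            count += 1
    return total / max(count, 1)  # avoid division by zero when everything is masked


lp = [[-0.5, -1.0], [-0.2, -0.3]]
print(sapo_token_mean_loss(lp, lp, [1.0, -1.0], [[1, 1], [1, 0]]))
```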
- classmethod default_args() → Dict
Get default initialization arguments for SAPO.
- Default configuration (from the SAPO paper):
tau_pos: 1.0 (temperature for positive advantages)
tau_neg: 1.05 (temperature for negative advantages)
loss_agg_mode: "token-mean" (average over tokens)
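The asymmetric temperatures (tau_neg > tau_pos) help stabilize training: a larger temperature makes the gate's gradient weight decay faster as the importance sampling ratio moves away from 1, so updates from off-policy tokens with negative advantages are suppressed more aggressively than those with positive advantages.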
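The effect of the asymmetric temperatures can be checked numerically. The gradient weight of the gate is 4·s·(1 - s) with s = σ(τ(x - 1)); it equals 1 at x = 1 and shrinks faster for a larger τ as the ratio moves away from 1 in either direction. This sketch (the helper name `gate_weight` is hypothetical) compares the default temperatures at a few ratios.

```python
import math


def gate_weight(ratio: float, tau: float) -> float:
    """Gradient weight df/dx = 4 * s * (1 - s), where s = sigmoid(tau * (x - 1))."""
    s = 1.0 / (1.0 + math.exp(-tau * (ratio - 1.0)))
    return 4.0 * s * (1.0 - s)


tau_pos, tau_neg = 1.0, 1.05  # default temperatures listed above
for ratio in (0.5, 1.0, 1.5, 2.0, 3.0):
    w_pos = gate_weight(ratio, tau_pos)
    w_neg = gate_weight(ratio, tau_neg)
    print(f"ratio={ratio:.1f}  weight(tau_pos)={w_pos:.4f}  weight(tau_neg)={w_neg:.4f}")
```

With the defaults the gap between the two columns is modest; raising tau_neg widens it.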
- Returns:
Dictionary of default arguments
- property select_keys
Returns the parameter keys, mapped to the naming convention of the configured training backend.