trinity.algorithm.policy_loss_fn.sapo_policy_loss module
SAPO policy loss function. Soft Adaptive Policy Optimization (SAPO) is a reinforcement learning algorithm that replaces the hard clipping of the importance sampling ratio used in PPO-style objectives with a smooth, temperature-controlled soft gate.
Refer to the SAPO paper for details. https://arxiv.org/abs/2511.20347
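For orientation, the objective that the soft gate f_{i,t} plugs into can be written as below. This is our reading of the paper's GRPO-style formulation (G sampled responses y_i per prompt q, behavior policy π_θold); consult the paper for the exact form. The loss actually minimized is the negative of this objective, and how the token terms are averaged is governed by loss_agg_mode.

```latex
r_{i,t}(\theta) = \frac{\pi_\theta(y_{i,t} \mid q, y_{i,<t})}{\pi_{\theta_{\mathrm{old}}}(y_{i,t} \mid q, y_{i,<t})},
\qquad
\mathcal{J}(\theta) = \mathbb{E}\left[\frac{1}{G}\sum_{i=1}^{G}\frac{1}{|y_i|}\sum_{t=1}^{|y_i|} f_{i,t}\big(r_{i,t}(\theta)\big)\,\hat{A}_{i,t}\right]
```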
- class trinity.algorithm.policy_loss_fn.sapo_policy_loss.SAPOPolicyLossFn(backend: str = 'verl', tau_pos: float = 1.0, tau_neg: float = 1.05, loss_agg_mode: str = 'token-mean')
Bases: PolicyLossFn
- __init__(backend: str = 'verl', tau_pos: float = 1.0, tau_neg: float = 1.05, loss_agg_mode: str = 'token-mean') → None
Initialize the SAPO policy loss function.
- Parameters:
backend – The training framework/backend to use (e.g., “verl”)
tau_pos – Temperature for positive advantages (τ_pos), default 1.0
tau_neg – Temperature for negative advantages (τ_neg), default 1.05, should be >= tau_pos
loss_agg_mode – Mode for aggregating per-token losses into a single scalar loss, default “token-mean”
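loss_agg_mode controls how per-token loss terms are reduced to one scalar. The sketch below contrasts the default "token-mean" with a per-sequence average. The mode name "seq-mean-token-mean" is assumed to follow verl's naming convention, and aggregate_loss is a hypothetical helper that uses ragged Python lists in place of masked tensors; it is not the library's implementation.

```python
from typing import List


def aggregate_loss(token_losses: List[List[float]], mode: str = "token-mean") -> float:
    """Reduce per-token losses (one inner list per response) to a scalar.

    Hypothetical helper for illustration only; mode names are assumed
    to mirror verl's loss_agg_mode values.
    """
    if mode == "token-mean":
        # Every token in the batch counts equally, so longer responses weigh more.
        flat = [x for seq in token_losses for x in seq]
        return sum(flat) / len(flat)
    if mode == "seq-mean-token-mean":
        # Average within each response first, then average across responses.
        per_seq = [sum(seq) / len(seq) for seq in token_losses]
        return sum(per_seq) / len(per_seq)
    raise ValueError(f"unsupported loss_agg_mode: {mode}")


batch = [[1.0, 1.0, 1.0], [4.0]]
print(aggregate_loss(batch, "token-mean"))
print(aggregate_loss(batch, "seq-mean-token-mean"))
```

For this batch, token-mean gives 7 / 4 = 1.75, while seq-mean-token-mean gives (1.0 + 4.0) / 2 = 2.5: under token-mean the three-token response dominates the average.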
- soft_gate_function(ratio: Tensor, advantages: Tensor) → Tensor
Compute the soft gate function f_{i,t}(x).
- The soft gate function is defined as:
f_{i,t}(x) = σ(τ_{i,t} * (x - 1)) * 4 / τ_{i,t}
- where:
σ is the sigmoid function
τ_{i,t} is the asymmetric temperature: tau_pos when the token's advantage is positive, tau_neg otherwise
x is the importance sampling ratio r_{i,t}(θ)
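Differentiating the gate shows why it acts as a smooth replacement for clipping. Writing s = σ(τ_{i,t}(x − 1)) and using σ′ = σ(1 − σ):

```latex
\frac{d f_{i,t}}{dx}
  = \frac{4}{\tau_{i,t}} \cdot \tau_{i,t}\, s\,(1 - s)
  = 4\, s\,(1 - s),
\qquad s = \sigma\big(\tau_{i,t}(x - 1)\big)
```

At x = 1 we have s = 1/2, so the gradient weight is exactly 1, matching the unclipped policy-gradient term. The weight decays smoothly toward 0 as x moves away from 1 in either direction, and a larger τ_{i,t} makes it decay faster.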
- Parameters:
ratio – Token-level importance sampling ratio r_{i,t}(θ)
advantages – Normalized advantage Â_i, shared by all tokens of sequence i; its sign determines which temperature is applied
- Returns:
The soft gate values for each token
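A scalar sketch of the gate for a single token, written in plain Python rather than on tensors. The rule "tau_pos for positive advantages, tau_neg otherwise" reflects our reading of the paper; the actual method operates elementwise on Tensor inputs and may differ in details. token_loss is a hypothetical helper showing how the gate enters the per-token loss.

```python
import math


def soft_gate(ratio: float, advantage: float,
              tau_pos: float = 1.0, tau_neg: float = 1.05) -> float:
    """Compute f(x) = sigmoid(tau * (x - 1)) * 4 / tau for one token."""
    # Asymmetric temperature: tau_pos for positive advantages, tau_neg otherwise (assumed rule).
    tau = tau_pos if advantage > 0 else tau_neg
    sigmoid = 1.0 / (1.0 + math.exp(-tau * (ratio - 1.0)))
    return sigmoid * 4.0 / tau


def token_loss(ratio: float, advantage: float,
               tau_pos: float = 1.0, tau_neg: float = 1.05) -> float:
    """Hypothetical per-token loss: maximizing f(r) * A means minimizing its negative."""
    return -soft_gate(ratio, advantage, tau_pos, tau_neg) * advantage


print(soft_gate(1.0, 1.0))   # on-policy token, positive advantage
print(soft_gate(1.0, -1.0))  # on-policy token, negative advantage
```

At ratio 1 the sigmoid equals 1/2, so the gate evaluates to 2 / tau: 2.0 with tau_pos = 1.0 and about 1.905 with tau_neg = 1.05. Unlike hard clipping, the gate never has a flat region, so every token keeps a nonzero (if shrinking) gradient.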
- classmethod default_args() → Dict
Get default initialization arguments for SAPO.
- Default configuration (from the SAPO paper):
tau_pos: 1.0 (temperature for positive advantages)
tau_neg: 1.05 (temperature for negative advantages)
loss_agg_mode: “token-mean” (average over tokens)
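The asymmetric temperatures (tau_neg > tau_pos) are meant to stabilize training: a larger temperature makes the gate's gradient weight decay faster as the ratio moves away from 1, so off-policy tokens with negative advantages are suppressed more aggressively than those with positive advantages.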
- Returns:
Dictionary of default arguments
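To make the effect of the default temperatures concrete, the sketch below evaluates the gradient weight 4·s·(1 − s), with s = σ(τ(r − 1)), which is the derivative of the soft gate σ(τ(r − 1)) · 4 / τ with respect to the ratio r, for tau_pos = 1.0 and tau_neg = 1.05. The DEFAULTS dictionary simply restates the documented defaults; it does not call default_args().

```python
import math

# Documented defaults for SAPOPolicyLossFn, restated as a plain dict.
DEFAULTS = {"tau_pos": 1.0, "tau_neg": 1.05, "loss_agg_mode": "token-mean"}


def gate_weight(ratio: float, tau: float) -> float:
    """Derivative of sigmoid(tau * (r - 1)) * 4 / tau with respect to r."""
    s = 1.0 / (1.0 + math.exp(-tau * (ratio - 1.0)))
    return 4.0 * s * (1.0 - s)


# Compare how quickly the weight decays for each temperature.
for r in (0.5, 1.0, 1.5, 2.0, 3.0):
    w_pos = gate_weight(r, DEFAULTS["tau_pos"])
    w_neg = gate_weight(r, DEFAULTS["tau_neg"])
    print(f"r={r:.1f}  weight(tau_pos)={w_pos:.4f}  weight(tau_neg)={w_neg:.4f}")
```

Both weights equal 1 at r = 1. For any other ratio the tau_neg weight is strictly smaller, because 4·s·(1 − s) shrinks as |τ(r − 1)| grows; with the defaults 1.0 and 1.05 the gap is modest, and widening it strengthens the suppression of negative-advantage tokens.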
- property select_keys
Returns the parameter keys, mapped to the naming convention of the training framework selected by backend.