trinity.algorithm.policy_loss_fn.importance_sampling_policy_loss module#
The simplest importance sampling policy loss.
loss = -(prob_ratio * advantages).sum() where prob_ratio = exp(current_logprobs - sampling_logprobs)
Note: This loss is used for on-policy distillation.
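A minimal PyTorch sketch of this computation, shown with the class's default “token-mean” aggregation; the function name, argument names, and mask handling are illustrative assumptions, not the class's actual interface:

```python
import torch

def importance_sampling_loss(
    current_logprobs: torch.Tensor,   # per-token log-probs under the current policy
    sampling_logprobs: torch.Tensor,  # per-token log-probs under the sampling policy
    advantages: torch.Tensor,         # per-token advantage estimates
    action_mask: torch.Tensor,        # 1.0 on response tokens, 0.0 on padding
) -> torch.Tensor:
    # prob_ratio = exp(current_logprobs - sampling_logprobs)
    prob_ratio = torch.exp(current_logprobs - sampling_logprobs)
    per_token_loss = -prob_ratio * advantages
    # "token-mean" aggregation: average the loss over all unmasked tokens
    return (per_token_loss * action_mask).sum() / action_mask.sum()
```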
- class trinity.algorithm.policy_loss_fn.importance_sampling_policy_loss.ImportanceSamplingLossFn(backend: str = 'verl', loss_agg_mode: str = 'token-mean')[source]#
Bases: PolicyLossFn

Pure importance sampling loss without clipping.
loss = -(ratio * advantages) where ratio = exp(logprob - old_logprob)
- __init__(backend: str = 'verl', loss_agg_mode: str = 'token-mean') → None[source]#
Initialize the policy loss function.
- Parameters:
backend – The training framework/backend to use (e.g., “verl”)
loss_agg_mode – How the per-token loss is aggregated (e.g., “token-mean”)
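For example, constructing the loss function with the documented defaults (import path taken from the module heading above):

```python
from trinity.algorithm.policy_loss_fn.importance_sampling_policy_loss import (
    ImportanceSamplingLossFn,
)

# Both arguments are optional and default to the values shown in the signature.
loss_fn = ImportanceSamplingLossFn(backend="verl", loss_agg_mode="token-mean")
```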
- classmethod default_args() → Dict[source]#
Get default initialization arguments for this loss function.
- Returns:
The default init arguments for the policy loss function.
- Return type:
Dict
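A common pattern is to expand these defaults back into the constructor, e.g. when wiring the loss function up from configuration; the exact keys in the returned dict are defined by the implementation:

```python
# Hedged sketch: default_args() is documented to return a Dict of init
# arguments, so it can be splatted straight into the constructor.
args = ImportanceSamplingLossFn.default_args()
loss_fn = ImportanceSamplingLossFn(**args)
```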
- property select_keys#
Returns parameter keys mapped to the specific training framework’s naming convention.
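As a hypothetical illustration (the actual key names depend on how the chosen backend labels log-probabilities, advantages, and masks in its batches):

```python
loss_fn = ImportanceSamplingLossFn(backend="verl")
# Inspect which batch fields this loss will read; the returned names are
# backend-specific and not enumerated here.
print(loss_fn.select_keys)
```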