trinity.algorithm.kl_fn.kl_fn module
KL penalty and loss.
Ref: volcengine/verl, volcengine/verl, OpenRLHF/OpenRLHF
- class trinity.algorithm.kl_fn.kl_fn.KLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]
Bases: ABC
Abstract base class for applying a KL penalty to rewards and computing a KL loss.
- __init__(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) → None[source]
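The constructor arguments adaptive, target_kl and horizon point to an adaptive KL coefficient. The sketch below assumes an update rule modeled on the adaptive KL controller found in verl and OpenRLHF (relative error clipped to ±0.2, scaled by n_steps / horizon). The class AdaptiveKLCoef and its update method are hypothetical; this is not confirmed to be what KLFn does internally.

```python
from dataclasses import dataclass


@dataclass
class AdaptiveKLCoef:
    """Hypothetical adaptive KL coefficient (assumed verl/OpenRLHF-style rule)."""

    target_kl: float
    horizon: float
    kl_coef: float = 0.001

    def update(self, current_kl: float, n_steps: int) -> float:
        # Relative deviation of the observed KL from the target, clipped to
        # [-0.2, 0.2] so a single update cannot move the coefficient too far.
        error = max(-0.2, min(0.2, current_kl / self.target_kl - 1.0))
        # Multiplicative update: grow kl_coef when KL is above target,
        # shrink it when KL is below target.
        self.kl_coef *= 1.0 + error * n_steps / self.horizon
        return self.kl_coef


ctl = AdaptiveKLCoef(target_kl=0.05, horizon=100.0, kl_coef=0.1)
print(ctl.update(current_kl=0.10, n_steps=10))  # KL above target: coefficient grows
```

The multiplicative form keeps the coefficient positive, and the clip bounds how fast it can drift in a single step.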
- apply_kl_penalty_to_reward(experiences: Any) → Tuple[Any, Dict][source]
Apply the KL penalty to the rewards. Only DataProto input is supported for now.
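Conceptually, a KL penalty on the reward subtracts kl_coef times a per-token KL estimate from each token's reward. DataProto is not reproduced here: the sketch below works on plain Python lists, uses the K1 estimate (logprob - ref_logprob) as a stand-in for whichever calculate_kl a subclass provides, and its function name and metric keys are hypothetical.

```python
from typing import Dict, List, Tuple


def apply_kl_penalty(
    token_rewards: List[float],
    logprob: List[float],
    ref_logprob: List[float],
    response_mask: List[int],
    kl_coef: float,
) -> Tuple[List[float], Dict[str, float]]:
    """Hypothetical reward shaping: r_t - kl_coef * kl_t on valid response tokens."""
    # Per-token K1 estimate (stand-in for calculate_kl), zeroed on padding via the mask.
    kl = [(lp - rlp) * m for lp, rlp, m in zip(logprob, ref_logprob, response_mask)]
    shaped = [r - kl_coef * k for r, k in zip(token_rewards, kl)]
    # Mean KL over valid tokens, reported as a metric.
    n_valid = max(sum(response_mask), 1)
    metrics = {"kl": sum(kl) / n_valid, "kl_coef": kl_coef}
    return shaped, metrics


rewards, metrics = apply_kl_penalty(
    token_rewards=[0.0, 0.0, 1.0],
    logprob=[-1.0, -2.0, -0.5],
    ref_logprob=[-1.5, -2.0, -1.0],
    response_mask=[1, 1, 1],
    kl_coef=0.1,
)
print(rewards, metrics)
```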
- calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor, loss_agg_mode: str, old_logprob: Tensor | None = None) → Tuple[Tensor, Dict][source]
Compute KL loss.
- Parameters:
logprob -- Log probabilities from current policy
ref_logprob -- Log probabilities from reference policy
response_mask -- Mask for valid response tokens
loss_agg_mode -- Loss aggregation mode
old_logprob -- Log probabilities from old policy (for importance sampling)
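calculate_kl_loss reduces a per-token KL tensor to a scalar using response_mask and loss_agg_mode. The sketch below assumes the mode names used by verl's loss aggregation ("token-mean", "seq-mean-token-sum", "seq-mean-token-mean") and assumes the result is scaled by kl_coef; both are assumptions. masked_agg and kl_loss are hypothetical helpers that operate on nested lists instead of tensors.

```python
from typing import Sequence


def masked_agg(
    values: Sequence[Sequence[float]],
    mask: Sequence[Sequence[int]],
    loss_agg_mode: str,
) -> float:
    """Reduce a [batch][seq_len] table of per-token values to one scalar."""
    if loss_agg_mode == "token-mean":
        # Average over every valid token in the batch.
        total = sum(v * m for rv, rm in zip(values, mask) for v, m in zip(rv, rm))
        count = sum(m for rm in mask for m in rm)
        return total / max(count, 1)
    per_seq = []
    for rv, rm in zip(values, mask):
        seq_sum = sum(v * m for v, m in zip(rv, rm))
        if loss_agg_mode == "seq-mean-token-sum":
            per_seq.append(seq_sum)  # token sum per sequence
        elif loss_agg_mode == "seq-mean-token-mean":
            per_seq.append(seq_sum / max(sum(rm), 1))  # token mean per sequence
        else:
            raise ValueError(f"unknown loss_agg_mode: {loss_agg_mode}")
    # Average the per-sequence values over the batch.
    return sum(per_seq) / max(len(per_seq), 1)


def kl_loss(
    kl_per_token: Sequence[Sequence[float]],
    response_mask: Sequence[Sequence[int]],
    loss_agg_mode: str,
    kl_coef: float,
) -> float:
    """Hypothetical KL loss: kl_coef times the aggregated per-token KL."""
    return kl_coef * masked_agg(kl_per_token, response_mask, loss_agg_mode)


kl = [[1.0, 2.0, 3.0], [4.0, 0.0, 0.0]]
mask = [[1, 1, 1], [1, 0, 0]]
for mode in ("token-mean", "seq-mean-token-sum", "seq-mean-token-mean"):
    print(mode, masked_agg(kl, mask, mode))
```

The modes differ in how much weight long sequences receive: "token-mean" weights every token equally, while the "seq-mean-*" modes weight every sequence equally.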
- abstractmethod calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) → Tensor[source]
Compute KL divergence between logprob and ref_logprob.
- Parameters:
logprob -- Log probabilities from current policy
ref_logprob -- Log probabilities from reference policy
old_logprob -- Log probabilities from old policy (for importance sampling)
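Because calculate_kl is the only abstract method, a new estimator only needs to override it. The sketch below mirrors that pattern with a hypothetical stand-in base class (KLFnSketch) on plain lists rather than importing trinity; SquaredLogRatioFn is likewise a hypothetical example estimator.

```python
from abc import ABC, abstractmethod
from typing import List, Optional


class KLFnSketch(ABC):
    """Hypothetical stand-in for KLFn; subclasses only supply calculate_kl."""

    def __init__(self, kl_coef: float = 0.001) -> None:
        self.kl_coef = kl_coef

    @abstractmethod
    def calculate_kl(
        self,
        logprob: List[float],
        ref_logprob: List[float],
        old_logprob: Optional[List[float]] = None,
    ) -> List[float]:
        """Return a per-token KL estimate with the same length as logprob."""


class SquaredLogRatioFn(KLFnSketch):
    """Hypothetical estimator: (log π_θ - log π_ref)^2, i.e. twice the K2 value."""

    def calculate_kl(self, logprob, ref_logprob, old_logprob=None):
        # old_logprob is accepted for interface compatibility but unused here.
        return [(lp - rlp) ** 2 for lp, rlp in zip(logprob, ref_logprob)]


print(SquaredLogRatioFn(kl_coef=0.01).calculate_kl([-1.0, -2.0], [-1.5, -2.0]))
```

Instantiating the base class directly raises TypeError, which is what forces every concrete KL function to define calculate_kl.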
- class trinity.algorithm.kl_fn.kl_fn.DummyKLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]
Bases: KLFn
Dummy KL function.
- calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) → Tensor[source]
Compute KL divergence between logprob and ref_logprob.
- Parameters:
logprob -- Log probabilities from current policy
ref_logprob -- Log probabilities from reference policy
old_logprob -- Log probabilities from old policy (for importance sampling)
- apply_kl_penalty_to_reward(experiences: Any) → Tuple[Any, Dict][source]
Apply the KL penalty to the rewards. Only DataProto input is supported for now.
- calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor, loss_agg_mode: str, old_logprob: Tensor | None = None) → Tuple[Tensor, Dict][source]
Compute KL loss.
- Parameters:
logprob -- Log probabilities from current policy
ref_logprob -- Log probabilities from reference policy
response_mask -- Mask for valid response tokens
loss_agg_mode -- Loss aggregation mode
old_logprob -- Log probabilities from old policy (for importance sampling)
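The name suggests that DummyKLFn disables KL regularization. A minimal sketch of that assumed behavior (zero per-token KL, rewards returned unchanged, zero loss) is shown below on plain Python values; DummyKLSketch is hypothetical, and the real class's tensor and DataProto types are not reproduced.

```python
from typing import Any, Dict, List, Optional, Tuple


class DummyKLSketch:
    """Hypothetical no-op KL: assumed behavior of DummyKLFn, not its source."""

    def calculate_kl(
        self,
        logprob: List[float],
        ref_logprob: List[float],
        old_logprob: Optional[List[float]] = None,
    ) -> List[float]:
        # Zero KL for every token, regardless of the inputs.
        return [0.0 for _ in logprob]

    def apply_kl_penalty_to_reward(self, experiences: Any) -> Tuple[Any, Dict]:
        # Rewards are passed through untouched; no metrics are reported.
        return experiences, {}

    def calculate_kl_loss(
        self,
        logprob: List[float],
        ref_logprob: List[float],
        response_mask: List[int],
        loss_agg_mode: str,
        old_logprob: Optional[List[float]] = None,
    ) -> Tuple[float, Dict]:
        # A zero loss contributes nothing to the policy gradient.
        return 0.0, {}


print(DummyKLSketch().calculate_kl([-1.0, -2.0], [-3.0, -4.0]))
```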
- class trinity.algorithm.kl_fn.kl_fn.K1Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]
Bases: KLFn
KL K1 function.
- calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) → Tensor[source]
Compute KL divergence between logprob and ref_logprob.
- Parameters:
logprob -- Log probabilities from current policy
ref_logprob -- Log probabilities from reference policy
old_logprob -- Log probabilities from old policy (for importance sampling)
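In the usual k1/k2/k3 naming for Monte Carlo KL estimators, K1 is the per-token log ratio log π_θ - log π_ref. Its expectation under π_θ equals KL(π_θ || π_ref), but individual values can be negative. The sketch below checks both properties on a toy three-token distribution, assuming K1Fn follows this standard definition; the helper k1 and the distributions are illustrative.

```python
import math
from typing import List


def k1(logprob: List[float], ref_logprob: List[float]) -> List[float]:
    """Per-token K1 estimate: log π_θ - log π_ref (assumed standard definition)."""
    return [lp - rlp for lp, rlp in zip(logprob, ref_logprob)]


# Toy distributions over three tokens: p plays π_θ, q plays π_ref.
p = [0.7, 0.2, 0.1]
q = [0.5, 0.3, 0.2]
vals = k1([math.log(x) for x in p], [math.log(x) for x in q])

# Weighting each value by p gives the expectation under π_θ, i.e. KL(p || q).
expected = sum(pi * v for pi, v in zip(p, vals))
print(vals)      # some per-token values are negative
print(expected)  # the expectation is still a non-negative KL
```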
- class trinity.algorithm.kl_fn.kl_fn.K2Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]
Bases: KLFn
KL K2 function.
- calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) → Tensor[source]
Compute KL divergence between logprob and ref_logprob.
- Parameters:
logprob -- Log probabilities from current policy
ref_logprob -- Log probabilities from reference policy
old_logprob -- Log probabilities from old policy (for importance sampling)
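Assuming the standard definition, K2 is half the squared log ratio, so every per-token value is non-negative. The trade-off is bias: it only tracks the true KL closely when π_θ and π_ref are near each other. The helper k2 below is illustrative.

```python
from typing import List


def k2(logprob: List[float], ref_logprob: List[float]) -> List[float]:
    """Per-token K2 estimate: 0.5 * (log π_θ - log π_ref)^2 (assumed definition)."""
    return [0.5 * (lp - rlp) ** 2 for lp, rlp in zip(logprob, ref_logprob)]


print(k2([-1.0, -2.0, -0.5], [-3.0, -2.0, -1.0]))
```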
- class trinity.algorithm.kl_fn.kl_fn.K3Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]
Bases: KLFn
KL K3 function.
- calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) → Tensor[source]
Compute KL divergence between logprob and ref_logprob.
- Parameters:
logprob -- Log probabilities from current policy
ref_logprob -- Log probabilities from reference policy
old_logprob -- Log probabilities from old policy (for importance sampling)
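Assuming the standard definition, K3 uses r = π_ref/π_θ and computes r - log r - 1 per token, the same expression given as KL_standard for CorrectedK3Fn. Every value is non-negative (since log r ≤ r - 1), and when tokens are sampled from π_θ its expectation is exactly KL(π_θ || π_ref). The sketch verifies both on a toy distribution; k3 is an illustrative helper.

```python
import math
from typing import List


def k3(logprob: List[float], ref_logprob: List[float]) -> List[float]:
    """Per-token K3 estimate: exp(log r) - log r - 1 with log r = log π_ref - log π_θ."""
    out = []
    for lp, rlp in zip(logprob, ref_logprob):
        log_r = rlp - lp
        out.append(math.exp(log_r) - log_r - 1.0)
    return out


p = [0.7, 0.2, 0.1]  # plays π_θ
q = [0.5, 0.3, 0.2]  # plays π_ref
vals = k3([math.log(x) for x in p], [math.log(x) for x in q])
# Expectation under π_θ: sum_x p(x) * (q/p - log(q/p) - 1) = KL(p || q).
expected = sum(pi * v for pi, v in zip(p, vals))
print(min(vals) >= 0.0, expected)
```

K3 combines the unbiasedness of K1 with the non-negativity of K2, which is why it is a common default for KL loss terms.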
- class trinity.algorithm.kl_fn.kl_fn.LowVarKLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]
Bases: KLFn
Low-variance KL function.
- calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) → Tensor[source]
Compute KL divergence between logprob and ref_logprob.
- Parameters:
logprob -- Log probabilities from current policy
ref_logprob -- Log probabilities from reference policy
old_logprob -- Log probabilities from old policy (for importance sampling)
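The docstring does not give the formula. In verl, the estimator named low_var_kl is K3 with clamping for numerical stability; the sketch below assumes LowVarKLFn follows that pattern, clamping the log ratio to [-20, 20] before exponentiating and the result to [-10, 10]. The clamp bounds and the helper name low_var_kl are assumptions borrowed from that reference, not confirmed for this class.

```python
import math
from typing import List


def low_var_kl(logprob: List[float], ref_logprob: List[float]) -> List[float]:
    """Hypothetical low-variance KL: clamped K3 (bounds assumed from verl)."""
    out = []
    for lp, rlp in zip(logprob, ref_logprob):
        # Clamp the log ratio so exp() cannot overflow.
        log_r = max(-20.0, min(20.0, rlp - lp))
        k3 = math.exp(log_r) - log_r - 1.0
        # Clamp the per-token estimate to limit the effect of outlier tokens.
        out.append(max(-10.0, min(10.0, k3)))
    return out


print(low_var_kl([-1.0, 0.0, 0.0], [-1.5, 50.0, -50.0]))
```

For typical tokens the result equals plain K3; only extreme log ratios are affected, which caps the contribution of rare outliers.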
- class trinity.algorithm.kl_fn.kl_fn.AbsFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]
Bases: KLFn
KL Abs function.
- calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) → Tensor[source]
Compute KL divergence between logprob and ref_logprob.
- Parameters:
logprob -- Log probabilities from current policy
ref_logprob -- Log probabilities from reference policy
old_logprob -- Log probabilities from old policy (for importance sampling)
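The docstring gives no formula for AbsFn. Assuming it matches the "abs" penalty in verl, the per-token value is the absolute log ratio |log π_θ - log π_ref|. Like K2 it is non-negative, but it grows linearly rather than quadratically with the log ratio. abs_kl is a hypothetical helper.

```python
from typing import List


def abs_kl(logprob: List[float], ref_logprob: List[float]) -> List[float]:
    """Hypothetical Abs estimator: |log π_θ - log π_ref| per token (assumed)."""
    return [abs(lp - rlp) for lp, rlp in zip(logprob, ref_logprob)]


print(abs_kl([-1.0, -3.0], [-3.0, -1.0]))  # symmetric in its two arguments
```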
- class trinity.algorithm.kl_fn.kl_fn.CorrectedK3Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]
Bases: KLFn
Corrected K3 function with importance sampling.
This class applies an importance-sampling correction to the standard K3 KL estimate. The corrected KL is computed as:
KL_corrected = (π_θ / π_old) * KL_standard
- where:
π_θ: current policy
π_old: old policy (from rollout)
π_ref: reference policy
KL_standard: exp(log(π_ref/π_θ)) - log(π_ref/π_θ) - 1, the standard K3 estimate of KL(π_θ || π_ref)
If old_logprob is not provided, it falls back to standard K3.
- calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) → Tensor[source]
Compute corrected K3 KL divergence with importance sampling.
- Parameters:
logprob -- Log probabilities from current policy (log π_θ)
ref_logprob -- Log probabilities from reference policy (log π_ref)
old_logprob -- Log probabilities from old policy (log π_old), optional
- Returns:
KL divergence tensor with same shape as input
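The corrected K3 formula can be sketched per token on plain Python lists. corrected_k3 is a hypothetical helper; whether the real implementation clips the importance weight or detaches it from the gradient is not stated in the docstring and is not modeled here. The final check illustrates why the weight is there: averaging the weighted values over tokens drawn from π_old recovers KL(π_θ || π_ref), the same quantity plain K3 estimates under π_θ.

```python
import math
from typing import List, Optional


def corrected_k3(
    logprob: List[float],
    ref_logprob: List[float],
    old_logprob: Optional[List[float]] = None,
) -> List[float]:
    """Per-token (π_θ / π_old) * K3, falling back to plain K3 without old_logprob."""
    out = []
    for i, (lp, rlp) in enumerate(zip(logprob, ref_logprob)):
        log_r = rlp - lp  # log(π_ref / π_θ)
        k3 = math.exp(log_r) - log_r - 1.0
        if old_logprob is None:
            out.append(k3)  # fallback: standard K3
        else:
            weight = math.exp(lp - old_logprob[i])  # importance weight π_θ / π_old
            out.append(weight * k3)
    return out


def logs(dist: List[float]) -> List[float]:
    return [math.log(x) for x in dist]


p = [0.7, 0.2, 0.1]  # π_θ
q = [0.5, 0.3, 0.2]  # π_ref
o = [0.4, 0.4, 0.2]  # π_old, the distribution the tokens were sampled from
vals = corrected_k3(logs(p), logs(q), logs(o))
# Expectation under π_old of the weighted values equals KL(π_θ || π_ref).
print(sum(oi * v for oi, v in zip(o, vals)))
```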