trinity.algorithm.kl_fn.kl_fn module#

KL penalty and loss.

Ref: volcengine/verl, OpenRLHF/OpenRLHF

class trinity.algorithm.kl_fn.kl_fn.KLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]#

Bases: ABC

KL penalty and loss.

__init__(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) None[source]#
update_kl_coef(current_kl: float, batch_size: int) None[source]#

Update the KL coefficient based on the current KL estimate.
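A hedged sketch of what an adaptive update typically looks like, in the spirit of the PPO KL controller (Ziegler et al., 2019). The clipping bound and update rule below are assumptions, not the confirmed implementation; only the signature and the adaptive/target_kl/horizon attributes come from this page.

    # Hypothetical adaptive update; trinity's exact rule may differ.
    def update_kl_coef(self, current_kl: float, batch_size: int) -> None:
        if not self.adaptive:
            return  # fixed variant: kl_coef stays at its initial value
        # Proportional error between observed and target KL, clipped to +/-0.2.
        error = min(max(current_kl / self.target_kl - 1.0, -0.2), 0.2)
        # Spread the correction over `horizon` samples.
        self.kl_coef *= 1.0 + error * batch_size / self.horizon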

apply_kl_penalty_to_reward(experiences: Any) Tuple[Any, Dict][source]#

Apply the KL penalty to the reward. Only supports DataProto inputs for now.
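A hedged sketch of the usual pattern. The DataProto field names ("old_log_probs", "ref_log_prob", "response_mask", "token_level_scores", "token_level_rewards") are assumptions about the batch schema, not confirmed by this page.

    # Hypothetical DataProto field names; the actual schema may differ.
    batch = experiences.batch
    kl = self.calculate_kl(batch["old_log_probs"], batch["ref_log_prob"])
    kl = kl * batch["response_mask"]            # zero out padding tokens
    batch["token_level_rewards"] = (
        batch["token_level_scores"] - self.kl_coef * kl
    )
    metrics = {"kl": (kl.sum(-1) / batch["response_mask"].sum(-1)).mean().item()}
    return experiences, metrics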

calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor, loss_agg_mode: str, old_logprob: Tensor | None = None) Tuple[Tensor, Dict][source]#

Compute KL loss.

Parameters:
  • logprob – Log probabilities from current policy

  • ref_logprob – Log probabilities from reference policy

  • response_mask – Mask for valid response tokens

  • loss_agg_mode – Loss aggregation mode

  • old_logprob – Log probabilities from old policy (for importance sampling)
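A hedged sketch of how these pieces likely compose. The mode name "token-mean", the metric keys, and the final scaling by kl_coef are assumptions; only the signature and return type come from this page.

    # Hypothetical composition; real aggregation modes and metric keys may differ.
    kl = self.calculate_kl(logprob, ref_logprob, old_logprob)   # per-token KL
    if loss_agg_mode == "token-mean":                           # assumed mode name
        # Mean over valid response tokens only.
        kl_loss = (kl * response_mask).sum() / response_mask.sum().clamp(min=1)
    else:
        raise NotImplementedError(loss_agg_mode)
    metrics = {"kl_loss": kl_loss.item(), "kl_coef": self.kl_coef}
    return self.kl_coef * kl_loss, metrics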

abstract calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) Tensor[source]#

Compute KL divergence between logprob and ref_logprob.

Parameters:
  • logprob – Log probabilities from current policy

  • ref_logprob – Log probabilities from reference policy

  • old_logprob – Log probabilities from old policy (for importance sampling)
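Since calculate_kl is the only abstract method, a minimal subclass needs just one override. The class name and estimator choice below are illustrative only:

    from torch import Tensor

    # Minimal hypothetical subclass; the estimator shown is the plain log-ratio.
    class PlainLogRatioKL(KLFn):
        def calculate_kl(
            self,
            logprob: Tensor,
            ref_logprob: Tensor,
            old_logprob: Tensor | None = None,
        ) -> Tensor:
            # Per-token log(pi_theta / pi_ref)
            return logprob - ref_logprob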

classmethod default_args()[source]#

Get the default initialization arguments.
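A plausible call pattern, assuming the returned dict mirrors __init__'s keyword defaults (not confirmed here; subclasses may override it):

    # Assumption: default_args mirrors __init__'s keyword defaults.
    args = KLFn.default_args()
    # e.g. {"adaptive": False, "kl_coef": 0.001, "target_kl": None, "horizon": None}
    kl_fn = K1Fn(**args)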

class trinity.algorithm.kl_fn.kl_fn.DummyKLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]#

Bases: KLFn

Dummy KL function.

calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) Tensor[source]#

Compute KL divergence between logprob and ref_logprob.

Parameters:
  • logprob – Log probabilities from current policy

  • ref_logprob – Log probabilities from reference policy

  • old_logprob – Log probabilities from old policy (for importance sampling)

apply_kl_penalty_to_reward(experiences: Any) Tuple[Any, Dict][source]#

Apply the KL penalty to the reward. Only supports DataProto inputs for now.

calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor, loss_agg_mode: str, old_logprob: Tensor | None = None) Tuple[Tensor, Dict][source]#

Compute KL loss.

Parameters:
  • logprob – Log probabilities from current policy

  • ref_logprob – Log probabilities from reference policy

  • response_mask – Mask for valid response tokens

  • loss_agg_mode – Loss aggregation mode

  • old_logprob – Log probabilities from old policy (for importance sampling)
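A plausible reading of "dummy" is that the KL term is disabled entirely. The sketch below is an assumption about the overrides, not the confirmed implementation:

    import torch

    # Assumed no-op behavior: zero KL, rewards passed through unchanged.
    def calculate_kl(self, logprob, ref_logprob, old_logprob=None):
        return torch.zeros_like(logprob)

    def apply_kl_penalty_to_reward(self, experiences):
        return experiences, {}      # no penalty applied, no metrics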

class trinity.algorithm.kl_fn.kl_fn.K1Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]#

Bases: KLFn

KL K1 function.

calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) Tensor[source]#

Compute KL divergence between logprob and ref_logprob.

Parameters:
  • logprob – Log probabilities from current policy

  • ref_logprob – Log probabilities from reference policy

  • old_logprob – Log probabilities from old policy (for importance sampling)
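By convention (Schulman's "Approximating KL Divergence" note, followed by verl and OpenRLHF), k1 is the plain log-ratio estimator. The sketch assumes that convention holds here:

    # Assumed k1 estimator: unbiased, higher variance, can be negative.
    def calculate_kl(self, logprob, ref_logprob, old_logprob=None):
        return logprob - ref_logprob        # log(pi_theta / pi_ref) per token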

class trinity.algorithm.kl_fn.kl_fn.K2Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]#

Bases: KLFn

KL K2 function.

calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) Tensor[source]#

Compute KL divergence between logprob and ref_logprob.

Parameters:
  • logprob – Log probabilities from current policy

  • ref_logprob – Log probabilities from reference policy

  • old_logprob – Log probabilities from old policy (for importance sampling)
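Under the same convention, k2 is the squared log-ratio, which is low-variance and always non-negative but biased. Again an assumption about the implementation:

    # Assumed k2 estimator: (log r)^2 / 2.
    def calculate_kl(self, logprob, ref_logprob, old_logprob=None):
        return 0.5 * (logprob - ref_logprob) ** 2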

class trinity.algorithm.kl_fn.kl_fn.K3Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]#

Bases: KLFn

KL K3 function.

calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) Tensor[source]#

Compute KL divergence between logprob and ref_logprob.

Parameters:
  • logprob – Log probabilities from current policy

  • ref_logprob – Log probabilities from reference policy

  • old_logprob – Log probabilities from old policy (for importance sampling)
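The k3 estimator is r - log r - 1 with r = pi_ref / pi_theta; it is unbiased and non-negative. A sketch assuming that standard form:

    # Assumed k3 estimator; matches the K3 term documented for CorrectedK3Fn below.
    def calculate_kl(self, logprob, ref_logprob, old_logprob=None):
        log_ratio = ref_logprob - logprob   # log(pi_ref / pi_theta)
        return log_ratio.exp() - log_ratio - 1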

class trinity.algorithm.kl_fn.kl_fn.LowVarKLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]#

Bases: KLFn

Low Variance KL function.

calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) Tensor[source]#

Compute KL divergence between logprob and ref_logprob.

Parameters:
  • logprob – Log probabilities from current policy

  • ref_logprob – Log probabilities from reference policy

  • old_logprob – Log probabilities from old policy (for importance sampling)
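In verl-style codebases, "low_var_kl" usually denotes the k3 estimator with clamping for numerical stability. The clamp bounds below are illustrative assumptions:

    # Assumed clamped-k3 form; clamp bounds are illustrative only.
    def calculate_kl(self, logprob, ref_logprob, old_logprob=None):
        log_ratio = (ref_logprob - logprob).clamp(-20.0, 20.0)  # guard exp()
        return (log_ratio.exp() - log_ratio - 1).clamp(max=10.0)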

class trinity.algorithm.kl_fn.kl_fn.AbsFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]#

Bases: KLFn

KL Abs function.

calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) Tensor[source]#

Compute KL divergence between logprob and ref_logprob.

Parameters:
  • logprob – Log probabilities from current policy

  • ref_logprob – Log probabilities from reference policy

  • old_logprob – Log probabilities from old policy (for importance sampling)
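"Abs" presumably denotes the absolute log-ratio, a non-negative penalty that ignores the direction of divergence. A sketch under that assumption:

    # Assumed absolute-difference form: |log(pi_theta / pi_ref)| per token.
    def calculate_kl(self, logprob, ref_logprob, old_logprob=None):
        return (logprob - ref_logprob).abs()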

class trinity.algorithm.kl_fn.kl_fn.CorrectedK3Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]#

Bases: KLFn

Corrected K3 function with importance sampling.

This method applies an importance-sampling correction to the standard K3 KL divergence. The corrected KL is computed as:

KL_corrected = (Ο€_ΞΈ / Ο€_old) * KL_standard(Ο€_ref || Ο€_ΞΈ)

where:
  • Ο€_ΞΈ: current policy

  • Ο€_old: old policy (from rollout)

  • Ο€_ref: reference policy

  • KL_standard: exp(log(Ο€_ref/Ο€_ΞΈ)) - log(Ο€_ref/Ο€_ΞΈ) - 1

If old_logprob is not provided, it falls back to the standard K3 estimator.

calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) Tensor[source]#

Compute corrected K3 KL divergence with importance sampling.

Parameters:
  • logprob – Log probabilities from current policy (log Ο€_ΞΈ)

  • ref_logprob – Log probabilities from reference policy (log Ο€_ref)

  • old_logprob – Log probabilities from old policy (log Ο€_old), optional

Returns:

KL divergence tensor with the same shape as the inputs
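The docstring pins down the math, so a near-direct transcription is possible; whether the importance ratio is detached from the computation graph or clamped is not stated, so those details are omitted here:

    # Transcription of the documented formula (detach/clamping unspecified).
    def calculate_kl(self, logprob, ref_logprob, old_logprob=None):
        log_ratio = ref_logprob - logprob           # log(pi_ref / pi_theta)
        k3 = log_ratio.exp() - log_ratio - 1        # standard K3 term
        if old_logprob is None:
            return k3                               # documented fallback
        return (logprob - old_logprob).exp() * k3   # (pi_theta / pi_old) * K3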