trinity.algorithm.kl_fn package
Module contents
- class trinity.algorithm.kl_fn.KLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]
Bases: ABC
KL penalty and loss.
- __init__(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) → None[source]
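When adaptive=True, the KL coefficient is typically adjusted toward target_kl over horizon steps. Below is a minimal sketch of such an update in the style of the adaptive controller from Ziegler et al. (2019); the exact rule is an assumption for illustration, not necessarily this package's implementation.

```python
# Hypothetical adaptive KL-coefficient update (Ziegler et al. style).
# This rule is an assumption for illustration, not Trinity's verified code.
def update_kl_coef(kl_coef: float, current_kl: float, target_kl: float,
                   horizon: float, n_steps: int) -> float:
    # Proportional error between observed and target KL, clipped to [-0.2, 0.2]
    # so a single noisy batch cannot swing the coefficient too far
    error = max(min(current_kl / target_kl - 1.0, 0.2), -0.2)
    # Nudge the coefficient; `horizon` controls how quickly it adapts
    return kl_coef * (1.0 + error * n_steps / horizon)
```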
- apply_kl_penalty_to_reward(experiences: Any) → Tuple[Any, Dict][source]
Apply the KL penalty to the reward. Only DataProto input is supported for now.
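For intuition, the penalty usually amounts to subtracting a scaled per-token KL estimate from the reward before advantage estimation. A hedged sketch with plain tensors (the real method operates on a DataProto, whose layout is not shown on this page):

```python
import torch

# Sketch of a per-token KL penalty on rewards. The function and field names
# are illustrative, not the actual DataProto layout.
def apply_kl_penalty(token_reward: torch.Tensor,   # (batch, seq_len)
                     logprob: torch.Tensor,        # (batch, seq_len)
                     ref_logprob: torch.Tensor,    # (batch, seq_len)
                     response_mask: torch.Tensor,  # (batch, seq_len), 1 = valid token
                     kl_coef: float):
    kl = logprob - ref_logprob                     # simple "k1" per-token KL estimate
    penalized = token_reward - kl_coef * kl * response_mask
    mean_kl = (kl * response_mask).sum() / response_mask.sum().clamp(min=1)
    return penalized, {"kl/mean": mean_kl.item()}
```

A metrics dict is returned alongside the penalized rewards, mirroring the Tuple[Any, Dict] return type above so callers can log the average penalty.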
- calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor, loss_agg_mode: str, old_logprob: Tensor | None = None) → Tuple[Tensor, Dict][source]
Compute the KL loss.
- Parameters:
logprob – Log probabilities from the current policy
ref_logprob – Log probabilities from the reference policy
response_mask – Mask for valid response tokens
loss_agg_mode – Loss aggregation mode
old_logprob – Log probabilities from the old policy (for importance sampling)
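As a rough illustration of the aggregation step, one common loss_agg_mode is a token-level mean over valid positions; the sketch below assumes that mode (the set of modes actually supported is not documented on this page):

```python
import torch

# Sketch of masked aggregation for a per-token KL tensor; "token-mean" is an
# assumed loss_agg_mode, shown for illustration only.
def aggregate_token_mean(kl: torch.Tensor, response_mask: torch.Tensor) -> torch.Tensor:
    # Average the per-token KL over valid response tokens only,
    # guarding against an all-zero mask
    return (kl * response_mask).sum() / response_mask.sum().clamp(min=1)
```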
- abstract calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) → Tensor[source]
Compute the KL divergence between logprob and ref_logprob.
- Parameters:
logprob – Log probabilities from the current policy
ref_logprob – Log probabilities from the reference policy
old_logprob – Log probabilities from the old policy (for importance sampling)
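Concrete subclasses supply the estimator. One possibility is the low-variance "k3" approximation from Schulman's note on approximating KL divergence; the subclass name and the choice of estimator below are assumptions, not part of this package:

```python
import torch
from trinity.algorithm.kl_fn import KLFn

# Hypothetical subclass using the "k3" estimator: exp(r) - r - 1 with
# r = ref_logprob - logprob. Its per-token expectation under the current
# policy approximates KL(pi || pi_ref), and each term is non-negative.
class K3KLFn(KLFn):
    def calculate_kl(self, logprob: torch.Tensor, ref_logprob: torch.Tensor,
                     old_logprob: torch.Tensor | None = None) -> torch.Tensor:
        ratio = ref_logprob - logprob
        return torch.exp(ratio) - ratio - 1.0  # always >= 0
```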