trinity.algorithm.kl_fn package

Submodules

Module contents

class trinity.algorithm.kl_fn.KLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) [source]

Bases: ABC

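Abstract base class for KL divergence functions. It provides both a KL penalty applied to rewards and a KL loss term; subclasses implement the KL estimator in calculate_kl.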

__init__(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) → None [source]

apply_kl_penalty_to_reward(experiences: Any) → Tuple[Any, Dict] [source]

Apply the KL penalty to the rewards. Only DataProto input is currently supported.
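As an illustration of what a per-token reward penalty involves, the sketch below subtracts kl_coef times a per-token KL estimate from each reward and reports the mean KL over valid tokens. It works on plain Python lists rather than DataProto or Tensor batches; the helper apply_kl_penalty, its arguments, and the choice of the simple log p minus log p_ref estimator are all assumptions for illustration, not the library's implementation.

```python
from typing import List, Sequence, Tuple


def apply_kl_penalty(
    rewards: Sequence[float],
    logprob: Sequence[float],
    ref_logprob: Sequence[float],
    response_mask: Sequence[int],
    kl_coef: float,
) -> Tuple[List[float], float]:
    """Hypothetical helper: subtract kl_coef * KL from each token's reward."""
    penalized: List[float] = []
    total_kl, n_valid = 0.0, 0
    for r, lp, ref_lp, m in zip(rewards, logprob, ref_logprob, response_mask):
        # Simple per-token estimator log p - log p_ref, zeroed on padding tokens.
        kl = (lp - ref_lp) * m
        penalized.append(r - kl_coef * kl)
        total_kl += kl
        n_valid += m
    # Mean KL over valid tokens, useful as a logged metric.
    mean_kl = total_kl / max(n_valid, 1)
    return penalized, mean_kl


penalized, mean_kl = apply_kl_penalty(
    rewards=[0.0, 0.0, 1.0],
    logprob=[-1.0, -2.0, -0.5],
    ref_logprob=[-1.5, -2.0, -1.0],
    response_mask=[1, 1, 0],
    kl_coef=0.1,
)
```

Here only the first token is penalized: its policy assigns it higher log probability than the reference does, and the third token is masked out.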

abstractmethod calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) → Tensor [source]

Compute KL divergence between logprob and ref_logprob.

Parameters:
  • logprob -- Log probabilities from current policy

  • ref_logprob -- Log probabilities from reference policy

  • old_logprob -- Log probabilities from old policy (for importance sampling)
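Because calculate_kl is abstract, each subclass chooses a KL estimator. The sketch below shows two widely used per-token estimators of KL(p || p_ref) from samples of p: "k1" (log p minus log p_ref, unbiased but can be negative) and "k3" (exp(r) - r - 1 with r = log p_ref - log p, always non-negative). The stand-in base class KLFnSketch and the subclass names are hypothetical, lists replace Tensors, and old_logprob is accepted but ignored; which estimators this package actually ships is not stated here.

```python
import math
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence


class KLFnSketch(ABC):
    """Hypothetical stand-in mirroring the abstract calculate_kl interface."""

    @abstractmethod
    def calculate_kl(
        self,
        logprob: Sequence[float],
        ref_logprob: Sequence[float],
        old_logprob: Optional[Sequence[float]] = None,
    ) -> List[float]:
        ...


class K1KL(KLFnSketch):
    """Per-token log p - log p_ref: unbiased, but individual values can be negative."""

    def calculate_kl(self, logprob, ref_logprob, old_logprob=None) -> List[float]:
        return [lp - ref for lp, ref in zip(logprob, ref_logprob)]


class K3KL(KLFnSketch):
    """Per-token exp(r) - r - 1 with r = log p_ref - log p: lower variance, always >= 0."""

    def calculate_kl(self, logprob, ref_logprob, old_logprob=None) -> List[float]:
        out: List[float] = []
        for lp, ref in zip(logprob, ref_logprob):
            log_ratio = ref - lp
            out.append(math.exp(log_ratio) - log_ratio - 1.0)
        return out


k1 = K1KL().calculate_kl([-1.0, -2.0], [-1.5, -2.0])
k3 = K3KL().calculate_kl([-1.0, -2.0], [-1.5, -2.0])
```

Both estimators return zero when the two policies agree on a token; k3 trades a small amount of bias for never producing negative per-token values.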

calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor, loss_agg_mode: str, old_logprob: Tensor | None = None) → Tuple[Tensor, Dict] [source]

Compute KL loss.

Parameters:
  • logprob -- Log probabilities from current policy

  • ref_logprob -- Log probabilities from reference policy

  • response_mask -- Mask for valid response tokens

  • loss_agg_mode -- Loss aggregation mode

  • old_logprob -- Log probabilities from old policy (for importance sampling)
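The role of response_mask and loss_agg_mode can be sketched as follows: per-token KL values are masked and then reduced to a scalar loss, returned alongside a metrics dict. The helper masked_kl_loss, the mode strings "token-mean" and "seq-mean-token-mean", and the metric key "kl_loss" are assumptions for illustration; the accepted values of loss_agg_mode are not listed in this reference.

```python
from typing import Dict, List, Sequence, Tuple


def masked_kl_loss(
    per_token_kl: Sequence[Sequence[float]],
    response_mask: Sequence[Sequence[int]],
    loss_agg_mode: str = "token-mean",
) -> Tuple[float, Dict[str, float]]:
    """Hypothetical helper: reduce a [batch][seq] KL matrix to a scalar loss."""
    if loss_agg_mode == "token-mean":
        # Average over every valid token in the batch, regardless of sequence.
        num = sum(
            k * m
            for row_k, row_m in zip(per_token_kl, response_mask)
            for k, m in zip(row_k, row_m)
        )
        den = sum(m for row in response_mask for m in row)
        loss = num / max(den, 1)
    elif loss_agg_mode == "seq-mean-token-mean":
        # Average within each sequence first, then across sequences.
        seq_means: List[float] = []
        for row_k, row_m in zip(per_token_kl, response_mask):
            n = sum(row_m)
            seq_means.append(sum(k * m for k, m in zip(row_k, row_m)) / max(n, 1))
        loss = sum(seq_means) / max(len(seq_means), 1)
    else:
        raise ValueError(f"unsupported loss_agg_mode: {loss_agg_mode}")
    return loss, {"kl_loss": loss}


kl = [[1.0, 3.0, 5.0], [4.0, 0.0, 0.0]]
mask = [[1, 1, 0], [1, 0, 0]]
token_loss, metrics = masked_kl_loss(kl, mask, "token-mean")   # 8 / 3
seq_loss, _ = masked_kl_loss(kl, mask, "seq-mean-token-mean")  # (2 + 4) / 2
```

The two modes differ in how much weight long sequences receive: token-level averaging lets longer responses dominate, while sequence-level averaging gives every response equal weight.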

classmethod default_args() [source]

Get the default initialization arguments.
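The return format of default_args is not documented here. One plausible way to implement such a classmethod is to read the defaults directly from the constructor signature, so they cannot drift out of sync with __init__. The class KLFnDefaultsSketch below is hypothetical; it only reuses the constructor parameters and defaults documented above.

```python
import inspect
from typing import Any, Dict, Optional


class KLFnDefaultsSketch:
    """Hypothetical class reusing the KLFn constructor signature documented above."""

    def __init__(
        self,
        adaptive: bool = False,
        kl_coef: float = 0.001,
        target_kl: Optional[float] = None,
        horizon: Optional[float] = None,
    ) -> None:
        self.adaptive = adaptive
        self.kl_coef = kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    @classmethod
    def default_args(cls) -> Dict[str, Any]:
        # Collect every __init__ parameter (except self) that has a default value.
        params = inspect.signature(cls.__init__).parameters
        return {
            name: p.default
            for name, p in params.items()
            if name != "self" and p.default is not inspect.Parameter.empty
        }


# Usage: start from the defaults and override selected fields.
args = {**KLFnDefaultsSketch.default_args(), "kl_coef": 0.01}
fn = KLFnDefaultsSketch(**args)
```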

update_kl_coef(current_kl: float, batch_size: int) → None [source]

Update the KL coefficient.
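The constructor arguments adaptive, target_kl, and horizon match the parameters of the adaptive KL controller described by Ziegler et al. (2019), "Fine-Tuning Language Models from Human Preferences". The sketch below implements that controller, assuming (not confirmed by this reference) that update_kl_coef follows the same scheme: the relative KL error is clipped to plus or minus 20%, and the coefficient is scaled by 1 + error * batch_size / horizon. AdaptiveKLSketch is a hypothetical stand-in class.

```python
from typing import Optional


class AdaptiveKLSketch:
    """Hypothetical stand-in holding the KLFn constructor arguments."""

    def __init__(
        self,
        adaptive: bool = False,
        kl_coef: float = 0.001,
        target_kl: Optional[float] = None,
        horizon: Optional[float] = None,
    ) -> None:
        self.adaptive = adaptive
        self.kl_coef = kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update_kl_coef(self, current_kl: float, batch_size: int) -> None:
        if not self.adaptive:
            return  # Fixed coefficient: nothing to update.
        if self.target_kl is None or self.horizon is None:
            raise ValueError("adaptive mode requires target_kl and horizon")
        # Relative deviation from the target KL, clipped for stability.
        proportional_error = min(max(current_kl / self.target_kl - 1.0, -0.2), 0.2)
        # Larger batches move the coefficient further; horizon sets the time scale.
        self.kl_coef *= 1.0 + proportional_error * batch_size / self.horizon


ctl = AdaptiveKLSketch(adaptive=True, kl_coef=0.1, target_kl=0.05, horizon=1000.0)
ctl.update_kl_coef(current_kl=0.1, batch_size=100)  # KL above target: coefficient grows
```

The clipping keeps a single noisy KL measurement from changing the coefficient by more than a small factor per update, and a larger horizon makes the adaptation slower.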