trinity.algorithm.kl_fn.kl_fn module
KL penalty and loss.
References: volcengine/verl, OpenRLHF/OpenRLHF
- class trinity.algorithm.kl_fn.kl_fn.KLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)
Bases: ABC
Abstract base class for KL penalty and KL loss functions.
- __init__(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) -> None
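The `adaptive`, `kl_coef`, `target_kl` and `horizon` arguments resemble the adaptive KL controller found in verl and OpenRLHF, both cited above. The sketch below shows how such a controller typically nudges the coefficient toward a target KL. It assumes, without confirmation from this page, that KLFn follows the same update rule; `KLCoefController` and its `update` method are hypothetical names, and plain Python floats stand in for tensors.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class KLCoefController:
    """Hypothetical helper mirroring KLFn's adaptive/kl_coef/target_kl/horizon arguments."""

    adaptive: bool = False
    kl_coef: float = 0.001
    target_kl: Optional[float] = None
    horizon: Optional[float] = None

    def update(self, current_kl: float, n_steps: int) -> float:
        """Move kl_coef toward target_kl in adaptive mode and return the new coefficient."""
        if not self.adaptive:
            return self.kl_coef  # fixed mode: the coefficient never changes
        if self.target_kl is None or self.horizon is None:
            raise ValueError("adaptive mode needs target_kl and horizon")
        # Relative deviation of the observed KL from the target, clipped to +/-20%.
        error = min(max(current_kl / self.target_kl - 1.0, -0.2), 0.2)
        # Scale the coefficient in proportion to the error and the number of steps taken.
        self.kl_coef *= 1.0 + error * n_steps / self.horizon
        return self.kl_coef
```

With `adaptive=False` the coefficient stays at `kl_coef`. With `adaptive=True`, an observed KL above `target_kl` multiplies the coefficient by at most `1 + 0.2 * n_steps / horizon`, and a KL below the target shrinks it by the same bound, so the penalty tightens or relaxes gradually rather than jumping.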
- apply_kl_penalty_to_reward(experiences: Any) -> Tuple[Any, Dict]
Apply the KL penalty to the reward. Only DataProto input is supported for now.
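A token-level reward penalty is commonly computed by subtracting `kl_coef` times a per-token KL estimate from each response token's reward, while returning a few summary metrics alongside. The sketch below assumes that convention and uses plain lists instead of a DataProto batch; `penalize_rewards` and the metric keys `"kl"` and `"kl_coef"` are hypothetical, not Trinity's API.

```python
from typing import Dict, List, Tuple


def penalize_rewards(
    token_rewards: List[float],
    kl: List[float],
    response_mask: List[int],
    kl_coef: float,
) -> Tuple[List[float], Dict[str, float]]:
    """Hypothetical helper: subtract kl_coef * KL from each response token's reward."""
    # Tokens with mask 0 (prompt or padding) keep their reward unchanged.
    penalized = [r - kl_coef * k * m for r, k, m in zip(token_rewards, kl, response_mask)]
    # Report the mean KL over response tokens only, guarding against an empty mask.
    n_tokens = sum(response_mask)
    mean_kl = sum(k * m for k, m in zip(kl, response_mask)) / max(n_tokens, 1)
    return penalized, {"kl": mean_kl, "kl_coef": kl_coef}
```

Returning the metrics dict next to the modified data matches the `Tuple[Any, Dict]` shape in the signature, which makes it easy to log the KL and coefficient every step.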
- calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor) -> Tuple[Tensor, Dict]
Compute the KL loss.
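One common way to turn per-token KL values into a scalar loss is a masked token mean scaled by the KL coefficient. The sketch below assumes that aggregation (Trinity may aggregate differently or apply the coefficient elsewhere). `masked_kl_loss` is a hypothetical helper, its default estimator is the simple difference `logprob - ref_logprob`, and lists stand in for tensors.

```python
from typing import Callable, Dict, List, Tuple


def masked_kl_loss(
    logprob: List[float],
    ref_logprob: List[float],
    response_mask: List[int],
    kl_coef: float,
    kl_estimator: Callable[[float, float], float] = lambda lp, rlp: lp - rlp,
) -> Tuple[float, Dict[str, float]]:
    """Hypothetical helper: kl_coef times the token-mean KL over response tokens."""
    per_token = [kl_estimator(lp, rlp) for lp, rlp in zip(logprob, ref_logprob)]
    n_tokens = sum(response_mask)
    # Tokens with mask 0 (prompt or padding) do not contribute to the mean.
    mean_kl = sum(k * m for k, m in zip(per_token, response_mask)) / max(n_tokens, 1)
    return kl_coef * mean_kl, {"kl": mean_kl}
```

Passing the estimator as a parameter mirrors how the subclasses below differ only in how they compute per-token KL, while the masking and averaging stay the same.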
- class trinity.algorithm.kl_fn.kl_fn.DummyKLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)
Bases: KLFn
Dummy KL function.
- calculate_kl(logprob: Tensor, ref_logprob: Tensor) -> Tensor
Compute the KL divergence between logprob and ref_logprob.
- class trinity.algorithm.kl_fn.kl_fn.K1Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)
Bases: KLFn
KL function based on the k1 estimator.
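Assuming `K1Fn` implements the k1 estimator, in the k1/k2/k3 naming popularized by John Schulman's note on approximating KL divergence and also used in verl, the per-token value is simply the log-probability difference. The sketch below uses plain lists in place of tensors, and `k1` is an illustrative function name. It checks on a toy three-token distribution that averaging k1 under the policy recovers the exact KL, even though individual values can be negative.

```python
import math
from typing import List


def k1(logprob: List[float], ref_logprob: List[float]) -> List[float]:
    """Per-token k1 estimate of KL(pi || pi_ref): log pi(a) - log pi_ref(a)."""
    # Unbiased in expectation over a ~ pi, but a single value can be negative.
    return [lp - rlp for lp, rlp in zip(logprob, ref_logprob)]


# Toy check on a 3-token vocabulary: the pi-weighted average of k1 equals the exact KL.
pi = [0.5, 0.3, 0.2]
pi_ref = [0.4, 0.4, 0.2]
per_token = k1([math.log(x) for x in pi], [math.log(x) for x in pi_ref])
expected_k1 = sum(p * k for p, k in zip(pi, per_token))
exact_kl = sum(p * math.log(p / r) for p, r in zip(pi, pi_ref))
```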
- class trinity.algorithm.kl_fn.kl_fn.K2Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)
Bases: KLFn
KL function based on the k2 estimator.
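Assuming `K2Fn` implements the k2 estimator (the k1/k2/k3 naming used in verl), the per-token value is half the squared log-ratio. It is never negative and typically has lower variance than k1, but its expectation is not exactly the KL. The sketch below uses plain lists in place of tensors, and `k2` is an illustrative function name; the toy check makes the bias visible.

```python
import math
from typing import List


def k2(logprob: List[float], ref_logprob: List[float]) -> List[float]:
    """Per-token k2 estimate: 0.5 * (log pi(a) - log pi_ref(a)) ** 2."""
    # Always >= 0, but biased as an estimate of KL(pi || pi_ref).
    return [0.5 * (lp - rlp) ** 2 for lp, rlp in zip(logprob, ref_logprob)]


# Toy check: the pi-weighted average of k2 is close to, but not equal to, the exact KL.
pi = [0.5, 0.3, 0.2]
pi_ref = [0.4, 0.4, 0.2]
per_token = k2([math.log(x) for x in pi], [math.log(x) for x in pi_ref])
expected_k2 = sum(p * k for p, k in zip(pi, per_token))
exact_kl = sum(p * math.log(p / r) for p, r in zip(pi, pi_ref))
```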
- class trinity.algorithm.kl_fn.kl_fn.K3Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)
Bases: KLFn
KL function based on the k3 estimator.
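Assuming `K3Fn` implements the k3 estimator (the k1/k2/k3 naming used in verl), the per-token value is `exp(d) - d - 1` with `d = ref_logprob - logprob`. Because `E_pi[exp(d)] = 1` and `E_pi[-d] = KL(pi || pi_ref)`, its expectation is exactly the KL, and because `exp(d) >= 1 + d` it is never negative. The sketch below uses plain lists in place of tensors, and `k3` is an illustrative function name; the toy check confirms both properties.

```python
import math
from typing import List


def k3(logprob: List[float], ref_logprob: List[float]) -> List[float]:
    """Per-token k3 estimate: exp(d) - d - 1 with d = log pi_ref(a) - log pi(a)."""
    # exp(d) - d - 1 >= 0 for every real d, and its mean under pi is KL(pi || pi_ref).
    return [math.exp(rlp - lp) - (rlp - lp) - 1.0 for lp, rlp in zip(logprob, ref_logprob)]


# Toy check: the pi-weighted average of k3 equals the exact KL, with no negative terms.
pi = [0.5, 0.3, 0.2]
pi_ref = [0.4, 0.4, 0.2]
per_token = k3([math.log(x) for x in pi], [math.log(x) for x in pi_ref])
expected_k3 = sum(p * k for p, k in zip(pi, per_token))
exact_kl = sum(p * math.log(p / r) for p, r in zip(pi, pi_ref))
```

Combining unbiasedness with non-negativity is why k3 is often preferred as a per-token KL penalty or loss term over k1 (unbiased but sign-varying) and k2 (non-negative but biased).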