trinity.algorithm.kl_fn

Submodules

trinity.algorithm.kl_fn.kl_fn module

KL penalty and loss.

References:

https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py
https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/utils.py

class trinity.algorithm.kl_fn.kl_fn.KLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]

Bases: ABC

KL penalty and loss.

__init__(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) None[source]
update_kl_coef(current_kl: float, batch_size: int) None[source]

Update the KL coefficient.
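
When adaptive is enabled, controllers of this kind usually follow the proportional update from Ziegler et al.; a minimal sketch under that assumption (the exact rule used here may differ):

    def adaptive_kl_update(kl_coef: float, current_kl: float, batch_size: int,
                           target_kl: float, horizon: float) -> float:
        # Hypothetical proportional controller: clip the relative error of the
        # observed KL against the target, then nudge the coefficient by a step
        # proportional to the fraction of the horizon this batch covers.
        proportional_error = max(min(current_kl / target_kl - 1.0, 0.2), -0.2)
        multiplier = 1.0 + proportional_error * batch_size / horizon
        return kl_coef * multiplier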

apply_kl_penalty_to_reward(experiences: Any) Tuple[Any, Dict][source]

Apply the KL penalty to the reward. Only DataProto input is supported for now.
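
A minimal sketch of the usual token-level shaping, assuming per-token rewards, per-token KL, and a response mask of shape (batch, response_len); the actual method operates on a verl DataProto batch and returns metrics alongside it:

    import torch

    def penalize_rewards(token_rewards: torch.Tensor, kl: torch.Tensor,
                         response_mask: torch.Tensor, kl_coef: float):
        # Hypothetical shaping: subtract the scaled KL from each response token's reward.
        penalized = token_rewards - kl_coef * kl * response_mask
        mean_kl = ((kl * response_mask).sum() / response_mask.sum()).item()
        return penalized, {"kl_penalty/mean_kl": mean_kl}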

calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor) Tuple[Tensor, Dict][source]

Compute KL loss.
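
A minimal sketch, assuming the loss is a response-mask-weighted mean of the per-token divergence, scaled by the current coefficient:

    def kl_loss(logprob, ref_logprob, response_mask, kl_coef, calculate_kl):
        kl = calculate_kl(logprob, ref_logprob)                  # per-token KL estimate
        loss = (kl * response_mask).sum() / response_mask.sum()  # mean over response tokens
        return kl_coef * loss, {"kl_loss": loss.item(), "kl_coef": kl_coef}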

abstract calculate_kl(logprob: Tensor, ref_logprob: Tensor) Tensor[source]

Compute KL divergence between logprob and ref_logprob.

classmethod default_args()[source]

Get the default initialization arguments.
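
A hedged usage example; it assumes default_args() returns a dict of keyword arguments matching __init__:

    from trinity.algorithm.kl_fn.kl_fn import K3Fn

    kwargs = K3Fn.default_args()   # e.g. {"adaptive": False, "kl_coef": 0.001, ...} (assumed)
    kl_fn = K3Fn(**kwargs)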

class trinity.algorithm.kl_fn.kl_fn.DummyKLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]

Bases: KLFn

Dummy KL function.

calculate_kl(logprob: Tensor, ref_logprob: Tensor) Tensor[source]

Compute KL divergence between logprob and ref_logprob.

apply_kl_penalty_to_reward(experiences: Any) Tuple[Any, Dict][source]

Apply the KL penalty to the reward. Only DataProto input is supported for now.

calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor) Tuple[Tensor, Dict][source]

Compute KL loss.
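
Presumably a no-op placeholder; a sketch of the assumed behavior, where the divergence is zero everywhere so neither the reward penalty nor the KL loss has any effect:

    import torch

    def dummy_kl(logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
        # Assumed behavior based on the name: zero KL for every token.
        return torch.zeros_like(logprob)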

class trinity.algorithm.kl_fn.kl_fn.K1Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]

Bases: KLFn

The k1 KL estimator.

calculate_kl(logprob: Tensor, ref_logprob: Tensor) Tensor[source]

Compute KL divergence between logprob and ref_logprob.
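
A minimal sketch of the k1 estimator (the plain log-ratio) from Schulman's "Approximating KL Divergence" note, which the referenced verl core_algos.py also implements; assumed to match this class:

    import torch

    def k1(logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
        # Per-token log-ratio log p - log p_ref: unbiased but relatively high variance.
        return logprob - ref_logprob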

class trinity.algorithm.kl_fn.kl_fn.K2Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]

Bases: KLFn

The k2 KL estimator.

calculate_kl(logprob: Tensor, ref_logprob: Tensor) Tensor[source]

Compute KL divergence between logprob and ref_logprob.
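
A minimal sketch of the k2 estimator, assumed to match this class:

    import torch

    def k2(logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
        # Half the squared log-ratio: low variance but biased.
        return 0.5 * (logprob - ref_logprob) ** 2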

class trinity.algorithm.kl_fn.kl_fn.K3Fn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]

Bases: KLFn

The k3 KL estimator.

calculate_kl(logprob: Tensor, ref_logprob: Tensor) Tensor[source]

Compute KL divergence between logprob and ref_logprob.
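
A minimal sketch of the k3 (low-variance, unbiased) estimator, assumed to match this class; implementations often also clamp the result for numerical stability:

    import torch

    def k3(logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
        # (r - 1) - log r with r = p_ref / p, evaluated per token.
        log_ratio = ref_logprob - logprob
        return log_ratio.exp() - log_ratio - 1.0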

class trinity.algorithm.kl_fn.kl_fn.AbsFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]

Bases: KLFn

The abs KL estimator.

calculate_kl(logprob: Tensor, ref_logprob: Tensor) Tensor[source]

Compute KL divergence between logprob and ref_logprob.
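
A minimal sketch of the absolute-difference estimator, assumed to match this class:

    import torch

    def abs_kl(logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
        # Absolute value of the per-token log-ratio.
        return (logprob - ref_logprob).abs()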

Module contents

class trinity.algorithm.kl_fn.KLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]

Bases: ABC

KL penalty and loss.

__init__(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) None[source]
update_kl_coef(current_kl: float, batch_size: int) None[source]

Update the KL coefficient.

apply_kl_penalty_to_reward(experiences: Any) Tuple[Any, Dict][source]

Apply the KL penalty to the reward. Only DataProto input is supported for now.

calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor) Tuple[Tensor, Dict][source]

Compute KL loss.

abstract calculate_kl(logprob: Tensor, ref_logprob: Tensor) Tensor[source]

Compute KL divergence between logprob and ref_logprob.

classmethod default_args()[source]

Get the default initialization arguments.