trinity.algorithm package

Subpackages

Submodules

Module contents

class trinity.algorithm.AlgorithmType[source]

Bases: ABC

use_critic: bool
use_reference: bool
compute_advantage_in_trainer: bool
can_balance_batch: bool
schema: type
abstract classmethod default_config() → Dict[source]
classmethod name() → str[source]
classmethod check_config(config: Config) → None[source]
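
For orientation, here is a minimal standalone sketch of how a concrete algorithm type could fill in these flags and default_config. Everything below (the stand-in base class, the `ExampleAlgorithm` name, the config keys, and the naming rule in `name()`) is illustrative, not trinity's actual implementation:

```python
from abc import ABC, abstractmethod
from typing import Dict

# Simplified stand-in for trinity.algorithm.AlgorithmType (illustration only).
class AlgorithmType(ABC):
    use_critic: bool
    use_reference: bool
    compute_advantage_in_trainer: bool
    can_balance_batch: bool

    @classmethod
    @abstractmethod
    def default_config(cls) -> Dict:
        ...

    @classmethod
    def name(cls) -> str:
        # Assumed naming rule: derive a registry name from the class name.
        return cls.__name__.lower()

# A hypothetical critic-free algorithm type (GRPO-like flag settings).
class ExampleAlgorithm(AlgorithmType):
    use_critic = False
    use_reference = True
    compute_advantage_in_trainer = True
    can_balance_batch = True

    @classmethod
    def default_config(cls) -> Dict:
        return {"repeat_times": 8, "kl_coef": 0.001}
```

check_config would then validate a Config instance against flags and defaults like these; its logic is omitted here.
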
class trinity.algorithm.AdvantageFn[source]

Bases: ABC

abstract classmethod default_args() → Dict[source]
Returns:

The default init arguments for the advantage function.

Return type:

Dict

classmethod compute_in_trainer() → bool[source]

Whether the advantage should be computed in the trainer loop.
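
As an illustration of the kind of computation a concrete AdvantageFn performs, here is a standalone group-baseline advantage in the style of GRPO. The function name and the `eps` default are ours, not trinity's:

```python
from statistics import mean, pstdev
from typing import List

def group_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Normalize each reward against its group's mean and standard deviation."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# Rewards for four rollouts of the same prompt; the advantages are zero-mean.
advs = group_advantages([1.0, 0.0, 1.0, 0.0])
```
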

class trinity.algorithm.PolicyLossFn(backend: str = 'verl')[source]

Bases: ABC

Abstract base class for policy loss functions.

This class provides the interface for implementing different policy gradient loss functions while handling parameter name mapping between different training frameworks.

__init__(backend: str = 'verl')[source]

Initialize the policy loss function.

Parameters:

backend – The training framework/backend to use (e.g., “verl”)

abstract classmethod default_args() → Dict[source]

Get default initialization arguments for this loss function.

Returns:

The default init arguments for the policy loss function.

Return type:

Dict

property select_keys

Returns parameter keys mapped to the specific training framework’s naming convention.
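
To make the interface concrete, here is a standalone PPO-style clipped surrogate, the kind of per-token computation a PolicyLossFn subclass might implement; a subclass's select_keys would then declare which of these inputs to fetch under the backend's names. Function and parameter names below are illustrative:

```python
import math

def ppo_clip_loss(logprob: float, old_logprob: float,
                  advantage: float, clip_range: float = 0.2) -> float:
    """PPO clipped surrogate for a single token (scalar sketch)."""
    ratio = math.exp(logprob - old_logprob)
    unclipped = ratio * advantage
    # Clamp the ratio to [1 - clip_range, 1 + clip_range] before weighting.
    clipped = max(min(ratio, 1.0 + clip_range), 1.0 - clip_range) * advantage
    # PPO maximizes the minimum of the two surrogates; the loss is its negation.
    return -min(unclipped, clipped)
```

With an unchanged policy (ratio = 1) the loss reduces to the negated advantage; a large ratio is cut off at 1 + clip_range.
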

class trinity.algorithm.KLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]

Bases: ABC

KL penalty and loss.

__init__(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) → None[source]
update_kl_coef(current_kl: float, batch_size: int) → None[source]

Update the KL coefficient.

apply_kl_penalty_to_reward(experiences: Any) → Tuple[Any, Dict][source]

Apply the KL penalty to the reward. Only DataProto input is supported for now.

calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor) → Tuple[Tensor, Dict][source]

Compute KL loss.

abstract calculate_kl(logprob: Tensor, ref_logprob: Tensor) → Tensor[source]

Compute KL divergence between logprob and ref_logprob.

classmethod default_args()[source]

Get the default initialization arguments.
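
Two common per-token KL estimators that a concrete calculate_kl might implement are sketched below as scalar functions (the "k1"/"k3" names follow Schulman's estimator notes; this is an illustrative standalone sketch, not trinity's code):

```python
import math

def kl_k1(logprob: float, ref_logprob: float) -> float:
    """Plain log-ratio: unbiased but high-variance estimate of KL(pi || ref)."""
    return logprob - ref_logprob

def kl_k3(logprob: float, ref_logprob: float) -> float:
    """Low-variance, always non-negative estimator: exp(d) - d - 1 with d = ref - pi."""
    d = ref_logprob - logprob
    return math.exp(d) - d - 1.0
```

Either quantity can be used as a reward penalty (via apply_kl_penalty_to_reward) or averaged under a response mask as a loss term (via calculate_kl_loss).
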

class trinity.algorithm.EntropyLossFn[source]

Bases: ABC

Entropy loss function.

classmethod default_args() → Dict[source]
Returns:

The default arguments for the entropy loss function.

Return type:

Dict
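
The quantity an entropy loss penalizes or rewards is the entropy of the policy's per-token distribution; a subclass would typically compute it from logits and average it under a response mask. A scalar sketch of the underlying formula (our own helper, not trinity's API):

```python
import math
from typing import List

def categorical_entropy(probs: List[float]) -> float:
    """Shannon entropy (in nats) of one token's probability distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

# A uniform distribution maximizes entropy; a peaked one is close to zero.
h_uniform = categorical_entropy([0.25] * 4)
h_peaked = categorical_entropy([0.97, 0.01, 0.01, 0.01])
```
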

class trinity.algorithm.SampleStrategy(buffer_config: BufferConfig, **kwargs)[source]

Bases: ABC

__init__(buffer_config: BufferConfig, **kwargs) → None[source]
abstract async sample(step: int) → Tuple[Experiences, Dict, List][source]

Sample data from buffer.

Parameters:

step (int) – The current step number.

Returns:

A tuple of (Experiences, Dict, List): the sampled experience data, metrics for logging, and representative data for logging.

Return type:

Tuple[Experiences, Dict, List]

abstract classmethod default_args() → dict[source]

Get the default arguments of the sample strategy.
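
To illustrate the sample contract (batch, metrics, representative data), here is a hypothetical strategy that reads from an in-memory list instead of a trinity buffer. The class, its constructor, and the metric keys are all assumptions for this sketch; the real strategy is configured via BufferConfig:

```python
import asyncio
from typing import Dict, List, Tuple

class ListSampleStrategy:
    """Hypothetical SampleStrategy-like class backed by a plain list."""

    def __init__(self, buffer: List[dict], batch_size: int = 2) -> None:
        self.buffer = buffer
        self.batch_size = batch_size

    async def sample(self, step: int) -> Tuple[List[dict], Dict, List]:
        start = step * self.batch_size
        batch = self.buffer[start:start + self.batch_size]
        metrics = {"sample/step": step, "sample/size": len(batch)}
        # Return the batch, logging metrics, and one representative sample.
        return batch, metrics, batch[:1]

strategy = ListSampleStrategy([{"reward": 1.0}, {"reward": 0.0}, {"reward": 2.0}])
batch, metrics, examples = asyncio.run(strategy.sample(0))
```
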