trinity.algorithm

Subpackages

Submodules

trinity.algorithm.algorithm module

Algorithm classes.

class trinity.algorithm.algorithm.ConstantMeta(name, bases, namespace, **kwargs)[source]

Bases: ABCMeta

class trinity.algorithm.algorithm.AlgorithmType[source]

Bases: ABC

use_critic: bool
use_reference: bool
use_advantage: bool
can_balance_batch: bool
schema: type
abstract classmethod default_config() Dict[source]
classmethod name() str[source]
classmethod check_config(config: Config) None[source]
class trinity.algorithm.algorithm.SFTAlgorithm[source]

Bases: AlgorithmType

SFT Algorithm.

use_critic: bool = False
use_reference: bool = False
use_advantage: bool = False
can_balance_batch: bool = True
schema

alias of SFTDataModel

classmethod default_config() Dict[source]
class trinity.algorithm.algorithm.PPOAlgorithm[source]

Bases: AlgorithmType

PPO Algorithm.

use_critic: bool = True
use_reference: bool = True
use_advantage: bool = True
can_balance_batch: bool = True
schema

alias of ExperienceModel

classmethod default_config() Dict[source]
class trinity.algorithm.algorithm.GRPOAlgorithm[source]

Bases: AlgorithmType

GRPO algorithm.

use_critic: bool = False
use_reference: bool = True
use_advantage: bool = True
can_balance_batch: bool = True
schema

alias of ExperienceModel

classmethod default_config() Dict[source]
class trinity.algorithm.algorithm.OPMDAlgorithm[source]

Bases: AlgorithmType

OPMD algorithm.

use_critic: bool = False
use_reference: bool = True
use_advantage: bool = True
can_balance_batch: bool = True
schema

alias of ExperienceModel

classmethod default_config() Dict[source]
class trinity.algorithm.algorithm.DPOAlgorithm[source]

Bases: AlgorithmType

DPO algorithm.

use_critic: bool = False
use_reference: bool = True
use_advantage: bool = False
can_balance_batch: bool = False
schema

alias of DPODataModel

classmethod default_config() Dict[source]
classmethod check_config(config: Config) None[source]
class trinity.algorithm.algorithm.MIXAlgorithm[source]

Bases: AlgorithmType

MIX algorithm.

use_critic: bool = False
use_reference: bool = True
use_advantage: bool = True
use_rollout: bool = True
can_balance_batch: bool = True
schema

alias of ExperienceModel

classmethod default_config() Dict[source]
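
The built-in algorithms above all follow the same pattern: fix the capability flags as class attributes, pick a schema, and supply default_config. A self-contained sketch of a hypothetical new algorithm (the base class below is a stand-in so the snippet runs on its own; real code would subclass trinity.algorithm.algorithm.AlgorithmType, and the hyper-parameter names are illustrative only):

```python
from abc import ABC, abstractmethod
from typing import Dict

class AlgorithmType(ABC):
    # Stand-in for trinity.algorithm.algorithm.AlgorithmType.
    use_critic: bool
    use_reference: bool
    use_advantage: bool
    can_balance_batch: bool

    @classmethod
    @abstractmethod
    def default_config(cls) -> Dict:
        ...

class MyCriticFreeAlgorithm(AlgorithmType):
    """Hypothetical algorithm mirroring GRPOAlgorithm's capability flags."""
    use_critic = False
    use_reference = True
    use_advantage = True
    can_balance_batch = True

    @classmethod
    def default_config(cls) -> Dict:
        # Illustrative hyper-parameters only, not Trinity's actual defaults.
        return {"repeat_times": 8, "kl_coef": 0.001}
```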

trinity.algorithm.algorithm_manager module

AlgorithmManager for switching between SFT and RFT.

class trinity.algorithm.algorithm_manager.AlgorithmManager(config: Config)[source]

Bases: object

__init__(config: Config)[source]
get_current_algorithm_config(global_steps: int)[source]
need_save(global_steps: int)[source]
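
The manager's job, per the module docstring, is to switch between SFT and RFT as training progresses. A minimal sketch of the switching idea; the field names (sft_warmup_steps, the per-phase configs) and the exact boundary condition are assumptions, not Trinity's actual config schema:

```python
def get_current_algorithm_config(global_steps, sft_warmup_steps,
                                 sft_config, rft_config):
    # Train with SFT for the first sft_warmup_steps, then switch to RFT.
    # Whether the boundary step counts as SFT or RFT is an assumption.
    if global_steps <= sft_warmup_steps:
        return sft_config
    return rft_config
```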

trinity.algorithm.key_mapper module

Key Mapper

class trinity.algorithm.key_mapper.KeyMapper(to_trinity_map: Dict[str, str])[source]

Bases: object

__init__(to_trinity_map: Dict[str, str])[source]
to_trinity(key: str) str[source]
from_trinity(key: str) str[source]
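
The mapper translates key names between a backend's convention and Trinity's, in both directions, from a single forward map. A reference sketch of this behavior (the pass-through for unknown keys is an assumption):

```python
class KeyMapper:
    """Bidirectional key translation built from a to-Trinity map."""

    def __init__(self, to_trinity_map):
        self.to_trinity_map = to_trinity_map
        # Invert the forward map for the reverse direction.
        self.from_trinity_map = {v: k for k, v in to_trinity_map.items()}

    def to_trinity(self, key):
        # Assumed: unknown keys pass through unchanged.
        return self.to_trinity_map.get(key, key)

    def from_trinity(self, key):
        return self.from_trinity_map.get(key, key)
```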

trinity.algorithm.utils module

Common utils for algorithm module.

Modified from https://github.com/volcengine/verl/blob/main/verl/utils/torch_functional.py

trinity.algorithm.utils.masked_sum(values, mask, axis=None)[source]

Compute the sum of tensor elements selected by a mask.

trinity.algorithm.utils.masked_mean(values, mask, axis=None)[source]

Compute the mean of tensor elements selected by a mask.

trinity.algorithm.utils.masked_var(values, mask, unbiased=True)[source]

Compute the variance of tensor elements selected by a mask.

trinity.algorithm.utils.masked_whiten(values, mask, shift_mean=True)[source]

Whiten values by normalizing with the mean and variance computed over the masked elements.

Parameters:
  • values (torch.Tensor) – Input tensor.

  • mask (torch.Tensor) – Boolean tensor of same shape, selects elements for stats.

  • shift_mean (bool) – If True (default), output is zero-mean; if False, the original mean is re-added after scaling.

Returns:

Whitened tensor of same shape as values.

Return type:

torch.Tensor
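
The masked reductions above can be read as ordinary reductions restricted to the elements where the mask is true. A pure-Python sketch of the semantics (the real functions operate on torch tensors and support an axis argument):

```python
def masked_sum(values, mask):
    # Sum only the elements where the mask is truthy.
    return sum(v for v, m in zip(values, mask) if m)

def masked_mean(values, mask):
    # Mean over the masked elements only.
    n = sum(1 for m in mask if m)
    return masked_sum(values, mask) / n

def masked_var(values, mask, unbiased=True):
    # Variance over the masked elements; Bessel-corrected if unbiased.
    mu = masked_mean(values, mask)
    sq = [(v - mu) ** 2 for v, m in zip(values, mask) if m]
    n = len(sq)
    return sum(sq) / (n - 1 if unbiased else n)
```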

trinity.algorithm.utils.prefix_metrics(src_metrics: dict, prefix: str, dst_metrics: dict | None = None) dict[source]
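
prefix_metrics presumably namespaces metric keys before logging. A sketch under the assumption that each key becomes "<prefix>/<key>" and results accumulate into dst_metrics:

```python
def prefix_metrics(src_metrics, prefix, dst_metrics=None):
    # Copy each metric under "<prefix>/<key>" into dst_metrics
    # (created if None) and return it.  The separator is an assumption.
    if dst_metrics is None:
        dst_metrics = {}
    for key, value in src_metrics.items():
        dst_metrics[f"{prefix}/{key}"] = value
    return dst_metrics
```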

Module contents

class trinity.algorithm.AlgorithmType[source]

Bases: ABC

use_critic: bool
use_reference: bool
use_advantage: bool
can_balance_batch: bool
schema: type
abstract classmethod default_config() Dict[source]
classmethod name() str[source]
classmethod check_config(config: Config) None[source]
class trinity.algorithm.AdvantageFn[source]

Bases: ABC

abstract classmethod default_args() Dict[source]
Returns:

The default init arguments for the advantage function.

Return type:

Dict
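
As an illustration of what an AdvantageFn computes, here is a GRPO-style group-relative advantage: each rollout's reward is standardized against the other rollouts for the same prompt. This is a sketch of the general technique, not Trinity's exact implementation:

```python
def group_advantages(rewards, eps=1e-6):
    # Standardize rewards within a group: (r - mean) / (std + eps).
    n = len(rewards)
    mean = sum(rewards) / n
    var = sum((r - mean) ** 2 for r in rewards) / n
    std = var ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]
```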

class trinity.algorithm.PolicyLossFn(backend: str = 'verl')[source]

Bases: ABC

Abstract base class for policy loss functions.

This class provides the interface for implementing different policy gradient loss functions while handling parameter name mapping between different training frameworks.

__init__(backend: str = 'verl')[source]

Initialize the policy loss function.

Parameters:

backend – The training framework/backend to use (e.g., “verl”)

abstract classmethod default_args() Dict[source]

Get default initialization arguments for this loss function.

Returns:

The default init arguments for the policy loss function.

Return type:

Dict

property select_keys

Returns parameter keys mapped to the specific training framework’s naming convention.
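
For intuition, the kind of function a PolicyLossFn subclass implements is the clipped PPO surrogate: scale the advantage by the probability ratio, clip the ratio, and take the pessimistic branch. A per-token pure-Python sketch (illustrative; the real implementations operate on tensors via the framework named by backend):

```python
import math

def ppo_policy_loss(logprob, old_logprob, advantages, clip_range=0.2):
    # -min(ratio * A, clip(ratio, 1-eps, 1+eps) * A), averaged per token.
    losses = []
    for lp, olp, adv in zip(logprob, old_logprob, advantages):
        ratio = math.exp(lp - olp)
        unclipped = ratio * adv
        clipped = max(min(ratio, 1 + clip_range), 1 - clip_range) * adv
        losses.append(-min(unclipped, clipped))
    return sum(losses) / len(losses)
```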

class trinity.algorithm.KLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]

Bases: ABC

KL penalty and loss.

__init__(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) None[source]
update_kl_coef(current_kl: float, batch_size: int) None[source]

Update the KL coefficient.

apply_kl_penalty_to_reward(experiences: Any) Tuple[Any, Dict][source]

Apply KL penalty to reward. Only supports DataProto input for now.

calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor) Tuple[Tensor, Dict][source]

Compute KL loss.

abstract calculate_kl(logprob: Tensor, ref_logprob: Tensor) Tensor[source]

Compute KL divergence between logprob and ref_logprob.

classmethod default_args()[source]

Get the default initialization arguments.
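
Two sketches of the pieces a KLFn provides. The per-token "k1" estimator is one common choice for calculate_kl, and the adaptive coefficient update follows the classic Ziegler et al.-style controller; whether Trinity's concrete subclasses use these exact forms is an assumption:

```python
def calculate_kl_k1(logprob, ref_logprob):
    # Per-token "k1" KL estimator: log p - log p_ref.
    return [lp - rlp for lp, rlp in zip(logprob, ref_logprob)]

def update_kl_coef_adaptive(kl_coef, current_kl, target_kl,
                            horizon, batch_size):
    # Nudge kl_coef toward keeping current_kl near target_kl,
    # with the proportional error clipped to +/-20%.
    error = max(min(current_kl / target_kl - 1.0, 0.2), -0.2)
    return kl_coef * (1.0 + error * batch_size / horizon)
```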

class trinity.algorithm.EntropyLossFn[source]

Bases: ABC

Entropy loss function.

classmethod default_args() Dict[source]
Returns:

The default arguments for the entropy loss function.

Return type:

Dict
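
The quantity behind an entropy loss is the per-token entropy of the policy's distribution, H = -sum(p * log p); an entropy bonus typically subtracts (a multiple of) this from the total loss to encourage exploration. A sketch of the per-token computation (illustrative; the real function operates on logits tensors under a response mask):

```python
import math

def token_entropy(probs):
    # Entropy of one token's probability distribution.
    return -sum(p * math.log(p) for p in probs if p > 0)
```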

class trinity.algorithm.SampleStrategy(buffer_config: BufferConfig, trainer_type: str, **kwargs)[source]

Bases: ABC

__init__(buffer_config: BufferConfig, trainer_type: str, **kwargs) None[source]
abstract sample(step: int) Tuple[Any, Dict, List][source]

Sample data from buffer.

Parameters:

step (int) – The current step number.

Returns:

Any: The sampled data. Dict: Metrics for logging. List: Representative data for logging.

Return type:

Tuple[Any, Dict, List]

abstract warmup_state(step: int) Tuple[bool, bool][source]

Check the warmup state of the current step.

Parameters:

step (int) – The current step number.

Returns:

bool: Whether the current step is in warmup. bool: Whether warmup finishes on this step.

Return type:

Tuple[bool, bool]

abstract classmethod default_args() dict[source]

Get the default arguments of the sample strategy.
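
A sketch of the (in_warmup, warmup_just_finished) pair that warmup_state returns, under the assumption that warmup means the first sft_warmup_steps of training; the exact boundary condition and field name are assumptions:

```python
def warmup_state(step, sft_warmup_steps):
    # First element: is this step still inside warmup?
    # Second element: does warmup finish exactly on this step?
    in_warmup = step <= sft_warmup_steps
    finished_now = step == sft_warmup_steps
    return in_warmup, finished_now
```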