trinity.algorithm
Subpackages
- trinity.algorithm.advantage_fn
- Submodules
- trinity.algorithm.advantage_fn.advantage_fn module
- trinity.algorithm.advantage_fn.grpo_advantage module
- trinity.algorithm.advantage_fn.opmd_advantage module
- trinity.algorithm.advantage_fn.ppo_advantage module
- trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage module
- trinity.algorithm.advantage_fn.remax_advantage module
- trinity.algorithm.advantage_fn.rloo_advantage module
- Module contents
- trinity.algorithm.entropy_loss_fn
- trinity.algorithm.kl_fn
- trinity.algorithm.policy_loss_fn
- Submodules
- trinity.algorithm.policy_loss_fn.dpo_loss module
- trinity.algorithm.policy_loss_fn.mix_policy_loss module
- trinity.algorithm.policy_loss_fn.opmd_policy_loss module
- trinity.algorithm.policy_loss_fn.policy_loss_fn module
- trinity.algorithm.policy_loss_fn.ppo_policy_loss module
- trinity.algorithm.policy_loss_fn.sft_loss module
- Module contents
- trinity.algorithm.sample_strategy
- Submodules
- trinity.algorithm.sample_strategy.mix_sample_strategy module
- trinity.algorithm.sample_strategy.sample_strategy module
- trinity.algorithm.sample_strategy.utils module
- Module contents
Submodules
trinity.algorithm.algorithm module
Algorithm classes.
- class trinity.algorithm.algorithm.ConstantMeta(name, bases, namespace, **kwargs)[source]
Bases:
ABCMeta
- class trinity.algorithm.algorithm.AlgorithmType[source]
Bases:
ABC
- use_critic: bool
- use_reference: bool
- use_advantage: bool
- can_balance_batch: bool
- schema: type
- class trinity.algorithm.algorithm.SFTAlgorithm[source]
Bases:
AlgorithmType
SFT Algorithm.
- use_critic: bool = False
- use_reference: bool = False
- use_advantage: bool = False
- can_balance_batch: bool = True
- schema
alias of
SFTDataModel
- class trinity.algorithm.algorithm.PPOAlgorithm[source]
Bases:
AlgorithmType
PPO Algorithm.
- use_critic: bool = True
- use_reference: bool = True
- use_advantage: bool = True
- can_balance_batch: bool = True
- schema
alias of
ExperienceModel
- class trinity.algorithm.algorithm.GRPOAlgorithm[source]
Bases:
AlgorithmType
GRPO algorithm.
- use_critic: bool = False
- use_reference: bool = True
- use_advantage: bool = True
- can_balance_batch: bool = True
- schema
alias of
ExperienceModel
- class trinity.algorithm.algorithm.OPMDAlgorithm[source]
Bases:
AlgorithmType
OPMD algorithm.
- use_critic: bool = False
- use_reference: bool = True
- use_advantage: bool = True
- can_balance_batch: bool = True
- schema
alias of
ExperienceModel
- class trinity.algorithm.algorithm.DPOAlgorithm[source]
Bases:
AlgorithmType
DPO algorithm.
- use_critic: bool = False
- use_reference: bool = True
- use_advantage: bool = False
- can_balance_batch: bool = False
- schema
alias of
DPODataModel
- class trinity.algorithm.algorithm.MIXAlgorithm[source]
Bases:
AlgorithmType
MIX algorithm.
- use_critic: bool = False
- use_reference: bool = True
- use_advantage: bool = True
- use_rollout: bool = True
- can_balance_batch: bool = True
- schema
alias of
ExperienceModel
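The flags above tell the trainer what each algorithm needs. For example, GRPO sets `use_critic = False` because it derives advantages from group-relative reward normalization instead of a learned value function. A minimal sketch of that computation follows; this is the textbook formulation for illustration only, not the code in `trinity.algorithm.advantage_fn.grpo_advantage`:

```python
import math

def grpo_advantages(group_rewards, eps=1e-6):
    """Group-relative advantages: normalize each reward by the mean and
    standard deviation of its rollout group (illustrative sketch)."""
    n = len(group_rewards)
    mean = sum(group_rewards) / n
    # Sample (Bessel-corrected) standard deviation, matching torch.std's default.
    var = sum((r - mean) ** 2 for r in group_rewards) / (n - 1)
    std = math.sqrt(var)
    return [(r - mean) / (std + eps) for r in group_rewards]
```

For a group of rewards `[1.0, 2.0, 3.0]`, this yields advantages of approximately `[-1, 0, 1]`: rollouts better than their group mean get positive advantage, worse ones negative.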
trinity.algorithm.algorithm_manager module
AlgorithmManager for switching between SFT and RFT.
trinity.algorithm.key_mapper module
Key Mapper
trinity.algorithm.utils module
Common utils for algorithm module.
Modified from https://github.com/volcengine/verl/blob/main/verl/utils/torch_functional.py
- trinity.algorithm.utils.masked_sum(values, mask, axis=None)[source]
Compute the sum of tensor values selected by a mask.
- trinity.algorithm.utils.masked_mean(values, mask, axis=None)[source]
Compute the mean of tensor values selected by a mask.
- trinity.algorithm.utils.masked_var(values, mask, unbiased=True)[source]
Compute the variance of tensor values selected by a mask.
- trinity.algorithm.utils.masked_whiten(values, mask, shift_mean=True)[source]
Whiten values by normalizing with mean and variance computed over mask.
- Parameters:
values (torch.Tensor) – Input tensor.
mask (torch.Tensor) – Boolean tensor of same shape, selects elements for stats.
shift_mean (bool) – If True (default), output is zero-mean; if False, the original mean is re-added after scaling.
- Returns:
Whitened tensor of same shape as values.
- Return type:
torch.Tensor
Module contents
- class trinity.algorithm.AlgorithmType[source]
Bases:
ABC
- use_critic: bool
- use_reference: bool
- use_advantage: bool
- can_balance_batch: bool
- schema: type
- class trinity.algorithm.PolicyLossFn(backend: str = 'verl')[source]
Bases:
ABC
Abstract base class for policy loss functions.
This class provides the interface for implementing different policy gradient loss functions while handling parameter name mapping between different training frameworks.
- __init__(backend: str = 'verl')[source]
Initialize the policy loss function.
- Parameters:
backend – The training framework/backend to use (e.g., “verl”)
- abstract classmethod default_args() Dict [source]
Get default initialization arguments for this loss function.
- Returns:
The default init arguments for the policy loss function.
- Return type:
Dict
- property select_keys
Returns parameter keys mapped to the specific training framework’s naming convention.
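As an illustration of what a concrete `PolicyLossFn` computes, here is a standard PPO clipped-surrogate loss in per-token form. This is a sketch of the textbook formulation, not the code in `trinity.algorithm.policy_loss_fn.ppo_policy_loss`:

```python
import math

def ppo_clip_loss(logprob, old_logprob, advantages, clip_range=0.2):
    """Per-token PPO clipped surrogate loss (illustrative sketch)."""
    losses = []
    for lp, old_lp, adv in zip(logprob, old_logprob, advantages):
        ratio = math.exp(lp - old_lp)  # importance-sampling ratio pi_new / pi_old
        unclipped = ratio * adv
        clipped = max(min(ratio, 1.0 + clip_range), 1.0 - clip_range) * adv
        losses.append(-min(unclipped, clipped))  # pessimistic (clipped) objective
    return sum(losses) / len(losses)
```

When the policy has not moved (`logprob == old_logprob`), the ratio is 1 and the loss reduces to the negated mean advantage; once the ratio leaves `[1 - clip_range, 1 + clip_range]`, the clipped branch caps the incentive to move further.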
- class trinity.algorithm.KLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]
Bases:
ABC
KL penalty and loss.
- __init__(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) None [source]
- apply_kl_penalty_to_reward(experiences: Any) Tuple[Any, Dict] [source]
Apply KL penalty to the reward. Only supports DataProto input for now.
- calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor) Tuple[Tensor, Dict] [source]
Compute KL loss.
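For intuition, a common per-token KL formulation used both as a reward penalty and as a loss term is sketched below. This is illustrative; the estimator actually used by the `KLFn` subclasses may differ. The "k3" estimator `exp(ref − logp) + (logp − ref) − 1` is non-negative and lower-variance than the plain logprob difference:

```python
import math

def kl_penalty(logprob, ref_logprob, kind="k3"):
    """Per-token KL estimate between the policy and a reference policy."""
    out = []
    for lp, ref in zip(logprob, ref_logprob):
        d = lp - ref
        # k1: plain logprob difference; k3: non-negative low-variance estimator.
        out.append(d if kind == "k1" else math.exp(-d) + d - 1.0)
    return out

def kl_loss(logprob, ref_logprob, response_mask, kl_coef=0.001):
    """Masked mean of the per-token KL, scaled by the KL coefficient."""
    per_token = kl_penalty(logprob, ref_logprob)
    masked = [k * m for k, m in zip(per_token, response_mask)]
    return kl_coef * sum(masked) / max(sum(response_mask), 1)
```

If the policy matches the reference exactly, both estimators give zero; as the policy drifts, the penalty grows and pulls it back toward the reference.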
- class trinity.algorithm.SampleStrategy(buffer_config: BufferConfig, trainer_type: str, **kwargs)[source]
Bases:
ABC
- __init__(buffer_config: BufferConfig, trainer_type: str, **kwargs) None [source]
- abstract sample(step: int) Tuple[Any, Dict, List] [source]
Sample data from buffer.
- Parameters:
step (int) – The current training step number.
- Returns:
The sampled data, metrics for logging, and representative samples for logging.
- Return type:
Tuple[Any, Dict, List]
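To show how the interface fits together, here is a standalone sketch of a hypothetical FIFO strategy. `SampleStrategy` is re-declared locally so the example runs on its own, and `buffer_config` is treated as a plain dict, whereas the real class receives a `BufferConfig`; the `FifoSampleStrategy` name and its fields are invented for illustration:

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple

class SampleStrategy(ABC):
    # Local stand-in mirroring the documented interface, not the real class.
    def __init__(self, buffer_config: Any, trainer_type: str, **kwargs) -> None:
        self.buffer_config = buffer_config
        self.trainer_type = trainer_type

    @abstractmethod
    def sample(self, step: int) -> Tuple[Any, Dict, List]:
        """Sample data from the buffer for the given step."""

class FifoSampleStrategy(SampleStrategy):
    # Hypothetical strategy: return the next batch_size items in order.
    def __init__(self, buffer_config, trainer_type, **kwargs):
        super().__init__(buffer_config, trainer_type, **kwargs)
        self.data = list(buffer_config["data"])
        self.batch_size = buffer_config["batch_size"]

    def sample(self, step):
        start = step * self.batch_size
        batch = self.data[start:start + self.batch_size]
        metrics = {"sample/step": step, "sample/count": len(batch)}
        return batch, metrics, batch[:1]  # data, metrics, representative sample
```

The three-tuple return shape (data, metrics dict, representative samples) matches the documented signature, so a trainer can log the middle and last elements while feeding the first to the update step.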