trinity.algorithm package
Subpackages
- trinity.algorithm.advantage_fn package
- Submodules
- trinity.algorithm.advantage_fn.advantage_fn module
- trinity.algorithm.advantage_fn.grpo_advantage module
- trinity.algorithm.advantage_fn.multi_step_grpo_advantage module
- trinity.algorithm.advantage_fn.opmd_advantage module
- trinity.algorithm.advantage_fn.ppo_advantage module
- trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage module
- trinity.algorithm.advantage_fn.remax_advantage module
- trinity.algorithm.advantage_fn.rloo_advantage module
- Module contents
- trinity.algorithm.entropy_loss_fn package
- trinity.algorithm.kl_fn package
- trinity.algorithm.policy_loss_fn package
- Submodules
- trinity.algorithm.policy_loss_fn.chord_policy_loss module
- trinity.algorithm.policy_loss_fn.dpo_loss module
- trinity.algorithm.policy_loss_fn.gspo_policy_loss module
- trinity.algorithm.policy_loss_fn.mix_policy_loss module
- trinity.algorithm.policy_loss_fn.opmd_policy_loss module
- trinity.algorithm.policy_loss_fn.policy_loss_fn module
- trinity.algorithm.policy_loss_fn.ppo_policy_loss module
- trinity.algorithm.policy_loss_fn.sft_loss module
- Module contents
- trinity.algorithm.sample_strategy package
Submodules
Module contents
- class trinity.algorithm.AlgorithmType[source]
Bases: ABC
- use_critic: bool
- use_reference: bool
- compute_advantage_in_trainer: bool
- can_balance_batch: bool
- schema: type
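The following minimal sketch shows how the documented AlgorithmType attributes might be set on a concrete subclass. The subclass name, the chosen attribute values, and the use of dict as a stand-in for the schema are illustrative assumptions, not definitions from this package.

```python
# Illustrative sketch only: the subclass name, the chosen attribute values,
# and the dict placeholder for `schema` are assumptions, not package defaults.
from trinity.algorithm import AlgorithmType


class CriticFreeAlgorithm(AlgorithmType):
    """Hypothetical algorithm description using the documented class attributes."""

    use_critic: bool = False                    # no value network is trained
    use_reference: bool = True                  # a frozen reference model is kept (e.g. for KL)
    compute_advantage_in_trainer: bool = False  # advantages are computed outside the trainer
    can_balance_batch: bool = True              # batches may be rebalanced
    schema: type = dict                         # placeholder; real code would point at an experience schema class
```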
- class trinity.algorithm.AdvantageFn[source]
Bases: ABC
- class trinity.algorithm.PolicyLossFn(backend: str = 'verl')[source]
Bases: ABC
Abstract base class for policy loss functions.
This class provides the interface for implementing different policy gradient loss functions while handling parameter name mapping between different training frameworks.
- __init__(backend: str = 'verl')[source]
Initialize the policy loss function.
- Parameters:
backend – The training framework/backend to use (e.g., “verl”)
- abstract classmethod default_args() → Dict [source]
Get default initialization arguments for this loss function.
- Returns:
The default init arguments for the policy loss function.
- Return type:
Dict
- property select_keys
Returns parameter keys mapped to the specific training framework’s naming convention.
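Below is a hedged sketch of a PolicyLossFn subclass. Only the interface documented above (__init__ and the abstract classmethod default_args) is exercised; the subclass name, the clip_range argument, and its default value are assumptions for illustration.

```python
# Sketch under assumptions: the subclass name, the clip_range argument, and its
# default value are illustrative; only the documented PolicyLossFn interface
# (__init__ and the abstract classmethod default_args) is used here.
from typing import Dict

from trinity.algorithm import PolicyLossFn


class ClippedSurrogateLoss(PolicyLossFn):
    """Hypothetical PPO-style clipped loss bound to the 'verl' backend."""

    def __init__(self, backend: str = "verl", clip_range: float = 0.2) -> None:
        # The base class handles parameter-name mapping for the chosen backend.
        super().__init__(backend=backend)
        self.clip_range = clip_range

    @classmethod
    def default_args(cls) -> Dict:
        # Default initialization arguments, as required by default_args().
        return {"clip_range": 0.2}
```

On such an instance, the inherited select_keys property would report the parameter keys under the configured backend's naming convention.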
- class trinity.algorithm.KLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]
Bases: ABC
KL penalty and loss.
- __init__(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) → None [source]
- apply_kl_penalty_to_reward(experiences: Any) → Tuple[Any, Dict] [source]
Apply the KL penalty to the reward. Only DataProto input is supported for now.
- calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor) → Tuple[Tensor, Dict] [source]
Compute KL loss.
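The sketch below illustrates calculate_kl_loss with a hypothetical concrete subclass. The subclass name and the plain log-ratio (k1) estimator are assumptions; the package's actual implementations live in the trinity.algorithm.kl_fn subpackage.

```python
# Sketch under assumptions: the subclass name and the plain log-ratio (k1)
# estimator are illustrative; handling of the kl_coef/adaptive settings from
# __init__ is intentionally omitted here.
from typing import Dict, Tuple

from torch import Tensor

from trinity.algorithm import KLFn


class LogRatioKLFn(KLFn):
    """Hypothetical KLFn computing a masked mean of logprob - ref_logprob."""

    def calculate_kl_loss(
        self, logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor
    ) -> Tuple[Tensor, Dict]:
        per_token_kl = (logprob - ref_logprob) * response_mask  # k1 estimate on response tokens
        kl_loss = per_token_kl.sum() / response_mask.sum().clamp(min=1)
        return kl_loss, {"kl_loss": kl_loss.item()}
```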
- class trinity.algorithm.SampleStrategy(buffer_config: BufferConfig, **kwargs)[source]
Bases: ABC
- __init__(buffer_config: BufferConfig, **kwargs) → None [source]
- abstract async sample(step: int) → Tuple[Experiences, Dict, List] [source]
Sample data from buffer.
- Parameters:
step (int) – The current step number.
- Returns:
A tuple containing the sampled Experiences data, a Dict of metrics for logging, and a List of representative data for logging.
- Return type:
Tuple[Experiences, Dict, List]
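To make the sample contract concrete, here is a hedged sketch of a SampleStrategy subclass. The subclass name and the pre-loaded in-memory batches standing in for a real buffer reader are assumptions; the actual buffer access API is configured through BufferConfig and is not documented on this page.

```python
# Sketch under assumptions: the subclass name and the in-memory list standing
# in for a real buffer reader are placeholders; the true buffer access API is
# configured through BufferConfig and is not documented on this page.
from typing import Dict, List, Tuple

from trinity.algorithm import SampleStrategy


class RoundRobinSampleStrategy(SampleStrategy):
    """Hypothetical strategy cycling through pre-loaded Experiences batches."""

    def __init__(self, buffer_config, preloaded_batches: List | None = None, **kwargs) -> None:
        super().__init__(buffer_config, **kwargs)
        self._batches: List = preloaded_batches or []

    async def sample(self, step: int) -> Tuple:
        if not self._batches:
            raise RuntimeError("no pre-loaded batches to sample from")
        experiences = self._batches[step % len(self._batches)]  # sampled Experiences data
        metrics: Dict = {"sample/step": step}                   # metrics for logging
        representative: List = []                                # representative data for logging
        return experiences, metrics, representative
```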