trinity.algorithm package#
Subpackages#
- trinity.algorithm.advantage_fn package
- Submodules
- trinity.algorithm.advantage_fn.advantage_fn module
- trinity.algorithm.advantage_fn.asymre_advantage module
- trinity.algorithm.advantage_fn.grpo_advantage module
- trinity.algorithm.advantage_fn.multi_step_grpo_advantage module
- trinity.algorithm.advantage_fn.opmd_advantage module
- trinity.algorithm.advantage_fn.ppo_advantage module
- trinity.algorithm.advantage_fn.reinforce_advantage module
- trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage module
- trinity.algorithm.advantage_fn.remax_advantage module
- trinity.algorithm.advantage_fn.rloo_advantage module
- Module contents
- trinity.algorithm.entropy_loss_fn package
- trinity.algorithm.kl_fn package
- trinity.algorithm.policy_loss_fn package
- Submodules
- trinity.algorithm.policy_loss_fn.chord_policy_loss module
- trinity.algorithm.policy_loss_fn.cispo_policy_loss module
- trinity.algorithm.policy_loss_fn.dpo_loss module
- trinity.algorithm.policy_loss_fn.gspo_policy_loss module
- trinity.algorithm.policy_loss_fn.mix_policy_loss module
- trinity.algorithm.policy_loss_fn.opmd_policy_loss module
- trinity.algorithm.policy_loss_fn.policy_loss_fn module
- trinity.algorithm.policy_loss_fn.ppo_policy_loss module
- trinity.algorithm.policy_loss_fn.sft_loss module
- trinity.algorithm.policy_loss_fn.sppo_loss_fn module
- trinity.algorithm.policy_loss_fn.topr_policy_loss module
- Module contents
- trinity.algorithm.sample_strategy package
Submodules#
- trinity.algorithm.algorithm module
- trinity.algorithm.algorithm_manager module
- trinity.algorithm.key_mapper module
- trinity.algorithm.utils module
Module contents#
- class trinity.algorithm.AlgorithmType[source]#
Bases: ABC
- use_critic: bool#
- use_reference: bool#
- compute_advantage_in_trainer: bool#
- can_balance_batch: bool#
- schema: str#
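The attributes above describe how the trainer should treat a given algorithm. Below is a minimal sketch of a concrete descriptor; the class name and attribute values are illustrative assumptions (not package defaults), and the sketch presumes AlgorithmType requires nothing beyond the attributes documented on this page.

```python
from trinity.algorithm import AlgorithmType


class MyGRPOLikeAlgorithm(AlgorithmType):
    """Illustrative descriptor; values are assumptions, not package defaults."""

    use_critic: bool = False                    # no value/critic model is trained
    use_reference: bool = True                  # keep a frozen reference policy for KL
    compute_advantage_in_trainer: bool = False  # advantages computed outside the trainer
    can_balance_batch: bool = True              # batches may be rebalanced across workers
    schema: str = "experience"                  # buffer schema this algorithm consumes (assumed name)
```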
- class trinity.algorithm.AdvantageFn[source]#
Bases: ABC
- class trinity.algorithm.PolicyLossFn(backend: str = 'verl')[source]#
Bases: ABC
Abstract base class for policy loss functions.
This class provides the interface for implementing different policy gradient loss functions while handling parameter name mapping between different training frameworks.
- __init__(backend: str = 'verl')[source]#
Initialize the policy loss function.
- Parameters:
backend – The training framework/backend to use (e.g., “verl”)
- abstract classmethod default_args() → Dict [source]#
Get default initialization arguments for this loss function.
- Returns:
The default init arguments for the policy loss function.
- Return type:
Dict
- property select_keys#
Returns parameter keys mapped to the specific training framework’s naming convention.
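To illustrate the documented interface, here is a sketch of a custom loss targeting the verl backend. The class name and the clip_range argument are hypothetical, and the sketch assumes default_args is the only abstract member that must be provided; the real base class may require more, for example the loss computation itself.

```python
from typing import Dict

from trinity.algorithm import PolicyLossFn


class ClippedPolicyLoss(PolicyLossFn):
    """Hypothetical loss; only the interface documented above is exercised."""

    def __init__(self, backend: str = "verl", clip_range: float = 0.2) -> None:
        super().__init__(backend=backend)  # lets the base class map keys to verl's names
        self.clip_range = clip_range

    @classmethod
    def default_args(cls) -> Dict:
        # Default construction arguments used when none are supplied in the config.
        return {"clip_range": 0.2}


loss_fn = ClippedPolicyLoss()
print(loss_fn.select_keys)  # parameter keys translated to the chosen backend's naming
```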
- class trinity.algorithm.KLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]#
Bases: ABC
KL penalty and loss.
- __init__(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) → None [source]#
- apply_kl_penalty_to_reward(experiences: Any) → Tuple[Any, Dict] [source]#
Apply KL penalty to reward. Only supports DataProto input for now.
- calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor) → Tuple[Tensor, Dict] [source]#
Compute KL loss.
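As a sketch of how the methods above fit together, the following defines a custom subclass with a simple k3-style estimator and computes a token-level KL loss. The subclass name, the estimator, and the assumption that KLFn declares no further abstract methods are all illustrative; the shipped implementations live under trinity.algorithm.kl_fn.

```python
from typing import Dict, Tuple

import torch
from torch import Tensor

from trinity.algorithm import KLFn


class K3KLFn(KLFn):
    """Hypothetical KL penalty using the k3 estimator (illustrative only)."""

    def __init__(self, kl_coef: float = 0.001) -> None:
        super().__init__(adaptive=False, kl_coef=kl_coef)
        self._coef = kl_coef  # stored locally; the base class may keep its own copy

    def calculate_kl_loss(
        self, logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor
    ) -> Tuple[Tensor, Dict]:
        # k3 estimator: exp(ref - pi) - (ref - pi) - 1, averaged over response tokens.
        log_ratio = ref_logprob - logprob
        kl = torch.exp(log_ratio) - log_ratio - 1.0
        kl_mean = (kl * response_mask).sum() / response_mask.sum().clamp(min=1.0)
        return self._coef * kl_mean, {"kl_loss": kl_mean.item()}


kl_fn = K3KLFn(kl_coef=0.001)
logprob = torch.randn(4, 128)      # policy log-probs, shape (batch, response_len)
ref_logprob = torch.randn(4, 128)  # reference-policy log-probs
mask = torch.ones(4, 128)          # 1 for response tokens, 0 for padding
kl_loss, metrics = kl_fn.calculate_kl_loss(logprob, ref_logprob, mask)
```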
- class trinity.algorithm.SampleStrategy(buffer_config: BufferConfig, **kwargs)[source]#
Bases: ABC
- __init__(buffer_config: BufferConfig, **kwargs) → None [source]#
- abstract async sample(step: int) → Tuple[Experiences, Dict, List] [source]#
Sample data from buffer.
- Parameters:
step (int) – The current step number.
- Returns:
A tuple of the sampled Experiences data, a Dict of metrics for logging, and a List of representative data for logging.
- Return type:
Tuple[Experiences, Dict, List]
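Finally, a sketch of a custom sampling strategy showing the shape of the async interface. How the buffer is actually read is not documented on this page, so the reader used below is a hypothetical placeholder, as are the class name and the warmup_steps argument; Experiences is replaced by Any in the annotation because its import path is not shown here.

```python
from typing import Any, Dict, List, Tuple

from trinity.algorithm import SampleStrategy


class WarmupSampleStrategy(SampleStrategy):
    """Hypothetical strategy that tags early steps as warmup (illustrative only)."""

    def __init__(self, buffer_config, warmup_steps: int = 10, **kwargs) -> None:
        # buffer_config is a trinity BufferConfig; its import path is not shown
        # on this page, so the parameter is left unannotated here.
        super().__init__(buffer_config, **kwargs)
        self.warmup_steps = warmup_steps

    async def sample(self, step: int) -> Tuple[Any, Dict, List]:
        # Hypothetical reader: the real buffer access API is not documented here.
        experiences = await self._read_batch_from_buffer(step)
        metrics = {"sample/step": step, "sample/is_warmup": step < self.warmup_steps}
        representative = []  # items to surface in logs, if any
        return experiences, metrics, representative
```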