trinity.algorithm package#

Module contents#

class trinity.algorithm.AlgorithmType[source]#

Bases: ABC

classmethod check_config(config: Config) → None[source]#
abstractmethod classmethod default_config() → Dict[source]#
classmethod name() → str[source]#
use_critic: bool#
use_reference: bool#
compute_advantage_in_trainer: bool#
can_balance_batch: bool#
schema: str#
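
A concrete algorithm type declares these class attributes and implements default_config() to name the components it wires together. The sketch below is only illustrative: the class name, attribute values, and config keys are assumptions, not part of this reference.

```python
from typing import Dict

from trinity.algorithm import AlgorithmType


class MyPPOAlgorithm(AlgorithmType):
    """Hypothetical algorithm type; all values below are illustrative."""

    use_critic: bool = True                 # train a value model alongside the policy
    use_reference: bool = True              # keep a frozen reference policy for KL
    compute_advantage_in_trainer: bool = True
    can_balance_batch: bool = True
    schema: str = "experience"              # assumed schema name

    @classmethod
    def default_config(cls) -> Dict:
        # Keys are assumptions: they name the component implementations
        # (advantage / policy-loss / KL / entropy functions) this algorithm uses.
        return {
            "advantage_fn": "gae",
            "policy_loss_fn": "ppo",
            "kl_loss_fn": "k2",
            "entropy_loss_fn": "default",
        }
```
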
class trinity.algorithm.AdvantageFn[source]#

Bases: ABC

classmethod compute_in_trainer() → bool[source]#

Whether the advantage should be computed in the trainer loop.

abstractmethod classmethod default_args() → Dict[source]#

Returns:

The default init arguments for the advantage function.

Return type:

Dict
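
Only default_args() and compute_in_trainer() are fixed by this interface; the call that actually computes advantages is not documented here. The sketch below therefore assumes a __call__ that receives a batch of experiences and returns the batch together with a metrics dict; the class name and the baseline argument are hypothetical.

```python
from typing import Any, Dict, Tuple

from trinity.algorithm import AdvantageFn


class ConstantBaselineAdvantage(AdvantageFn):
    """Hypothetical AdvantageFn: advantage = reward - fixed baseline."""

    def __init__(self, baseline: float = 0.0):
        self.baseline = baseline

    @classmethod
    def default_args(cls) -> Dict:
        # The keyword arguments used to construct this advantage function.
        return {"baseline": 0.0}

    @classmethod
    def compute_in_trainer(cls) -> bool:
        # Compute advantages inside the trainer loop rather than beforehand.
        return True

    def __call__(self, exps: Any, **kwargs) -> Tuple[Any, Dict]:  # assumed signature
        # Illustrative only: return the batch unchanged plus logging metrics.
        metrics = {"baseline": self.baseline}
        return exps, metrics
```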

class trinity.algorithm.PolicyLossFn(backend: str = 'verl')[source]#

Bases: ABC

Abstract base class for policy loss functions.

This class provides the interface for implementing different policy gradient loss functions while handling parameter-name mapping between different training frameworks.

__init__(backend: str = 'verl')[source]#

Initialize the policy loss function.

Parameters:

backend -- The training framework/backend to use (e.g., "verl")

abstractmethod classmethod default_args() → Dict[source]#

Get default initialization arguments for this loss function.

Returns:

The default init arguments for the policy loss function.

Return type:

Dict

property select_keys#

Returns parameter keys mapped to the specific training framework's naming convention.
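
As a concrete example, the sketch below implements the standard PPO clipped surrogate objective as a PolicyLossFn. Only __init__(backend), default_args(), and select_keys are fixed by this reference; the __call__ signature (per-token log-probabilities, advantages, and a response mask) and the clip_range argument are assumptions for illustration.

```python
from typing import Dict, Tuple

import torch

from trinity.algorithm import PolicyLossFn


class ClippedPolicyLoss(PolicyLossFn):
    """Hypothetical PPO-style clipped surrogate loss."""

    def __init__(self, backend: str = "verl", clip_range: float = 0.2):
        super().__init__(backend=backend)
        self.clip_range = clip_range

    @classmethod
    def default_args(cls) -> Dict:
        return {"clip_range": 0.2}

    def __call__(  # assumed signature
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        advantages: torch.Tensor,
        response_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        # Importance ratio between the current and the behaviour policy.
        ratio = torch.exp(logprob - old_logprob)
        clipped = torch.clamp(ratio, 1.0 - self.clip_range, 1.0 + self.clip_range)
        # Pessimistic (min) clipped surrogate, negated to form a loss.
        surrogate = -torch.min(ratio * advantages, clipped * advantages)
        # Average over valid response tokens only.
        loss = (surrogate * response_mask).sum() / response_mask.sum().clamp(min=1)
        return loss, {"policy_loss": loss.item()}
```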

class trinity.algorithm.KLFn(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None)[source]#

Bases: ABC

KL penalty and loss.

__init__(adaptive: bool = False, kl_coef: float = 0.001, target_kl: float | None = None, horizon: float | None = None) → None[source]#
apply_kl_penalty_to_reward(experiences: Any) → Tuple[Any, Dict][source]#

Apply KL penalty to the reward. Only DataProto input is supported for now.

abstractmethod calculate_kl(logprob: Tensor, ref_logprob: Tensor, old_logprob: Tensor | None = None) → Tensor[source]#

Compute the KL divergence between logprob and ref_logprob.

Parameters:
  • logprob -- Log probabilities from the current policy

  • ref_logprob -- Log probabilities from the reference policy

  • old_logprob -- Log probabilities from the old policy (for importance sampling)

calculate_kl_loss(logprob: Tensor, ref_logprob: Tensor, response_mask: Tensor, loss_agg_mode: str, old_logprob: Tensor | None = None) → Tuple[Tensor, Dict][source]#

Compute the KL loss.

Parameters:
  • logprob -- Log probabilities from the current policy

  • ref_logprob -- Log probabilities from the reference policy

  • response_mask -- Mask for valid response tokens

  • loss_agg_mode -- Loss aggregation mode

  • old_logprob -- Log probabilities from the old policy (for importance sampling)

classmethod default_args()[source]#

Get the default initialization arguments.

update_kl_coef(current_kl: float, batch_size: int) → None[source]#

Update the KL coefficient.
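
Only calculate_kl is abstract; reward penalisation, loss aggregation, and coefficient updates are provided by the base class. The sketch below fills in calculate_kl with the simple k1 estimator, logprob - ref_logprob; the class name and the choice of estimator are illustrative, not part of this reference.

```python
from typing import Optional

from torch import Tensor

from trinity.algorithm import KLFn


class K1KLFn(KLFn):
    """Hypothetical KLFn using the per-token k1 estimator."""

    def calculate_kl(
        self,
        logprob: Tensor,
        ref_logprob: Tensor,
        old_logprob: Optional[Tensor] = None,
    ) -> Tensor:
        # k1 estimate of KL(current || reference); old_logprob is unused here.
        return logprob - ref_logprob


# Example instantiation with illustrative values. Based on the constructor and
# update_kl_coef signatures, adaptive=True presumably lets the base class adjust
# kl_coef towards target_kl over the given horizon.
kl_fn = K1KLFn(adaptive=True, kl_coef=1e-3, target_kl=0.01, horizon=10000)
```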

class trinity.algorithm.EntropyLossFn[source]#

Bases: ABC

Entropy loss function.

classmethod default_args() → Dict[source]#

Returns:

The default arguments for the entropy loss function.

Return type:

Dict
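
Only default_args() is documented above. The sketch below assumes the entropy loss is called with precomputed per-token entropies and a response mask, and that it is constructed with an entropy_coef argument; those details and the class name are assumptions for illustration.

```python
from typing import Dict, Tuple

import torch

from trinity.algorithm import EntropyLossFn


class ScaledEntropyLoss(EntropyLossFn):
    """Hypothetical entropy bonus scaled by a fixed coefficient."""

    def __init__(self, entropy_coef: float = 0.001):
        self.entropy_coef = entropy_coef

    @classmethod
    def default_args(cls) -> Dict:
        return {"entropy_coef": 0.001}

    def __call__(  # assumed signature
        self, entropy: torch.Tensor, response_mask: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, Dict]:
        # Masked mean entropy, negated so that higher entropy lowers the loss.
        mean_entropy = (entropy * response_mask).sum() / response_mask.sum().clamp(min=1)
        loss = -self.entropy_coef * mean_entropy
        return loss, {"entropy": mean_entropy.item()}
```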

class trinity.algorithm.SampleStrategy(buffer_config: BufferConfig, **kwargs)[source]#

Bases: ABC

__init__(buffer_config: BufferConfig, **kwargs) → None[source]#
abstractmethod classmethod default_args() → dict[source]#

Get the default arguments of the sample strategy.

abstractmethod load_state_dict(state_dict: dict) → None[source]#

Load the state dict of the sample strategy.

abstractmethod async sample(step: int) → Tuple[List[Experience], Dict, List][source]#

Sample data from the buffer.

Parameters:

step (int) -- The current step number.

Returns:

A tuple containing: List[Experience] -- the sampled experience data; Dict -- metrics for logging; List -- representative data for logging.

Return type:

Tuple[List[Experience], Dict, List]

set_model_version_metric(exp_list: List[Experience], metrics: Dict)[source]#
abstractmethod state_dict() → dict[source]#

Get the state dict of the sample strategy.
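
A concrete strategy implements default_args(), the async sample() hook, and the state_dict()/load_state_dict() pair for checkpointing. The sketch below is illustrative: the in-memory queue, the import paths for BufferConfig and Experience, and the class name are assumptions, not part of this reference.

```python
from typing import Dict, List, Tuple

from trinity.algorithm import SampleStrategy
from trinity.common.config import BufferConfig        # assumed import path
from trinity.common.experience import Experience      # assumed import path


class QueueSampleStrategy(SampleStrategy):
    """Hypothetical strategy that pops a ready batch from an in-memory queue."""

    def __init__(self, buffer_config: BufferConfig, **kwargs) -> None:
        super().__init__(buffer_config, **kwargs)
        self._queue: List[List[Experience]] = []  # filled elsewhere (illustrative)

    @classmethod
    def default_args(cls) -> dict:
        return {}

    async def sample(self, step: int) -> Tuple[List[Experience], Dict, List]:
        exps = self._queue.pop(0)
        metrics: Dict = {}
        # Record how stale the sampled experiences are relative to the current model.
        self.set_model_version_metric(exps, metrics)
        return exps, metrics, []

    def state_dict(self) -> dict:
        # Nothing to checkpoint for this simple strategy.
        return {}

    def load_state_dict(self, state_dict: dict) -> None:
        pass
```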