trinity.algorithm.policy_loss_fn.importance_sampling_policy_loss module

The simplest importance sampling policy loss.

loss = -(prob_ratio * advantages), aggregated according to loss_agg_mode, where prob_ratio = exp(current_logprobs - sampling_logprobs)

Note: This loss is used for on-policy distillation.
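A minimal PyTorch sketch of this computation (an illustration, not the Trinity-RFT implementation; the function name, tensor shapes, and the action_mask argument are assumptions):

   import torch

   def importance_sampling_loss(
       current_logprobs: torch.Tensor,   # log pi_theta(a|s), shape (batch, seq_len)
       sampling_logprobs: torch.Tensor,  # log pi_sampling(a|s), same shape
       advantages: torch.Tensor,         # per-token advantages, same shape
       action_mask: torch.Tensor,        # 1.0 for response tokens, 0.0 for padding
   ) -> torch.Tensor:
       # Probability ratio between the current policy and the sampling
       # policy, computed in log space for numerical stability.
       prob_ratio = torch.exp(current_logprobs - sampling_logprobs)
       # Pure importance-sampled policy-gradient term, with no clipping.
       per_token_loss = -(prob_ratio * advantages)
       # "token-mean" aggregation: average over unmasked (response) tokens.
       return (per_token_loss * action_mask).sum() / action_mask.sum().clamp(min=1.0)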

class trinity.algorithm.policy_loss_fn.importance_sampling_policy_loss.ImportanceSamplingLossFn(backend: str = 'verl', loss_agg_mode: str = 'token-mean')[source]

Bases: PolicyLossFn

Pure importance sampling loss without clipping.

loss = -(ratio * advantages) where ratio = exp(logprob - old_logprob)

__init__(backend: str = 'verl', loss_agg_mode: str = 'token-mean') → None [source]

Initialize the policy loss function.

Parameters:

backend -- The training framework/backend to use (e.g., "verl")

loss_agg_mode -- How the per-token loss is aggregated (e.g., "token-mean")
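A hypothetical construction example (assuming the class is importable from the module path above):

   from trinity.algorithm.policy_loss_fn.importance_sampling_policy_loss import (
       ImportanceSamplingLossFn,
   )

   # Explicit arguments; these values match the documented defaults.
   loss_fn = ImportanceSamplingLossFn(backend="verl", loss_agg_mode="token-mean")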

classmethod default_args() → Dict [source]

Get default initialization arguments for this loss function.

Returns:

The default init arguments for the policy loss function.

Return type:

Dict
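Since default_args() returns the default init arguments, a sketch of the intended usage (an assumption based on the documented contract, not a confirmed pattern) is to unpack it into the constructor:

   # Build the loss function from its own declared defaults.
   loss_fn = ImportanceSamplingLossFn(**ImportanceSamplingLossFn.default_args())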

property select_keys

Returns parameter keys mapped to the specific training framework's naming convention.