trinity.algorithm.utils module

Common utils for algorithm module.

Modified from volcengine/verl

trinity.algorithm.utils.aggregate_loss(values, mask, loss_agg_mode='token-mean', normalizer=None)

Compute a scalar loss from values and mask using one of several aggregation modes. Modified from volcengine/verl.

Parameters:
  • values (torch.Tensor) -- Arbitrary shape tensor of values to aggregate.

  • mask (torch.BoolTensor or torch.FloatTensor) -- Same shape as values, 1/True = include, 0 = ignore.

  • loss_agg_mode (str) -- One of the following:
      - "token-mean": mean over all unmasked elements.
      - "seq-mean-token-sum": average over sequences, where each sequence's loss is the sum of its unmasked values.
      - "seq-mean-token-mean": average over sequences, where each sequence's loss is the mean of its unmasked values.
      - "seq-mean-token-sum-norm": total sum of unmasked values divided by a fixed normalizer (e.g., the sequence length).

  • normalizer (float or None) -- Only used in 'seq-mean-token-sum-norm'. If None, uses mask.shape[-1].

Returns:

Scalar loss value.

Return type:

torch.Tensor
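To make the arithmetic of the four modes concrete, here is a minimal sketch in plain PyTorch. It is not the library source: the name aggregate_loss_sketch is hypothetical, and it assumes the token dimension is the last axis of values (the seq-* modes reduce over dim=-1). In this sketch a fully masked sequence yields NaN under "seq-mean-token-mean".

```python
import torch


def aggregate_loss_sketch(values, mask, loss_agg_mode="token-mean", normalizer=None):
    """Hypothetical re-implementation of the documented modes (not the library code).

    Assumes the last dimension of `values` indexes tokens within a sequence.
    """
    mask = mask.to(values.dtype)  # accept bool or float masks
    masked = values * mask        # zero out ignored positions

    if loss_agg_mode == "token-mean":
        # Mean over every unmasked element in the batch.
        return masked.sum() / mask.sum()
    if loss_agg_mode == "seq-mean-token-sum":
        # Each sequence contributes the sum of its unmasked values.
        return masked.sum(dim=-1).mean()
    if loss_agg_mode == "seq-mean-token-mean":
        # Each sequence contributes the mean of its unmasked values.
        return (masked.sum(dim=-1) / mask.sum(dim=-1)).mean()
    if loss_agg_mode == "seq-mean-token-sum-norm":
        # Total unmasked sum divided by a fixed normalizer (default: mask.shape[-1]).
        if normalizer is None:
            normalizer = mask.shape[-1]
        return masked.sum() / normalizer
    raise ValueError(f"Unknown loss_agg_mode: {loss_agg_mode}")
```

For values [[1, 2, 3], [4, 5, 6]] with mask [[1, 1, 1], [1, 0, 0]], the four modes give 2.5, 5.0, 3.0 and 10/3 respectively, which shows how the mode choice changes the relative weight of long and short sequences.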

trinity.algorithm.utils.masked_sum(values, mask, axis=None)

Compute the sum of values over the elements selected by mask, optionally along axis.
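A minimal sketch of the described behaviour, assuming mask is a 0/1 (or boolean) tensor that broadcasts against values and that axis=None means summing over every element. The name masked_sum_sketch is hypothetical, not the library function.

```python
import torch


def masked_sum_sketch(values, mask, axis=None):
    """Hypothetical sketch: zero out unselected positions, then sum."""
    masked = values * mask  # bool masks are promoted to the dtype of values
    # axis=None reduces over all elements; otherwise reduce along `axis`.
    return masked.sum() if axis is None else masked.sum(dim=axis)
```

With values [[1, 2, 3], [4, 5, 6]] and mask [[1, 1, 0], [0, 1, 1]], the full sum is 14 and the per-row sums (axis=-1) are [3, 11].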

trinity.algorithm.utils.masked_mean(values, mask, axis=None)

Compute the mean of values over the elements selected by mask, optionally along axis.
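The masked mean divides the masked sum by the number of selected elements, rather than by the full tensor size. A minimal sketch, assuming the same mask and axis conventions as masked_sum; masked_mean_sketch is a hypothetical name, and it adds no epsilon, so a slice with no selected elements yields NaN here.

```python
import torch


def masked_mean_sketch(values, mask, axis=None):
    """Hypothetical sketch: sum of selected values / count of selected values."""
    mask = mask.to(values.dtype)
    if axis is None:
        return (values * mask).sum() / mask.sum()
    # Per-slice mean along `axis`; each slice is normalized by its own count.
    return (values * mask).sum(dim=axis) / mask.sum(dim=axis)
```

With values [[1, 2, 3], [4, 5, 6]] and mask [[1, 1, 0], [0, 1, 1]], the overall masked mean is 14 / 4 = 3.5 and the per-row means (axis=-1) are [1.5, 5.5].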

trinity.algorithm.utils.masked_var(values, mask, unbiased=True)

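Compute the variance of values over the elements selected by mask. If unbiased is True (default), the Bessel-corrected sample variance is returned.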
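A minimal sketch of a masked variance, assuming "unbiased" means Bessel's correction n / (n - 1), where n is the number of selected elements. The name masked_var_sketch and the choice to raise ValueError when fewer than two elements are selected are assumptions for illustration.

```python
import torch


def masked_var_sketch(values, mask, unbiased=True):
    """Hypothetical sketch: variance of the elements selected by `mask`."""
    mask = mask.to(values.dtype)
    n = mask.sum()
    mean = (values * mask).sum() / n
    # Population variance over the selected elements only.
    var = (((values - mean) ** 2) * mask).sum() / n
    if unbiased:
        if n <= 1:
            raise ValueError("unbiased variance needs at least two selected elements")
        var = var * n / (n - 1)  # Bessel's correction
    return var
```

For the selected values [2, 4, 4, 4, 5, 5, 7, 9] (mean 5), the biased variance is 32 / 8 = 4 and the unbiased variance is 32 / 7; any masked-out outlier does not affect either result.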

trinity.algorithm.utils.masked_whiten(values, mask, shift_mean=True)

Whiten values by normalizing with mean and variance computed over mask.

Parameters:
  • values (torch.Tensor) -- Input tensor.

  • mask (torch.Tensor) -- Boolean tensor of the same shape as values; selects the elements used to compute the mean and variance.

  • shift_mean (bool) -- If True (default), output is zero-mean; if False, the original mean is re-added after scaling.

Returns:

Whitened tensor of the same shape as values.

Return type:

torch.Tensor
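A minimal sketch of masked whitening, assuming the statistics are the masked mean and the unbiased masked variance, with a small eps added before the inverse square root for numerical stability. The name masked_whiten_sketch, the eps parameter and its value 1e-8 are assumptions for illustration, not confirmed library details.

```python
import torch


def masked_whiten_sketch(values, mask, shift_mean=True, eps=1e-8):
    """Hypothetical sketch: standardize `values` with stats from masked elements only."""
    mask = mask.to(values.dtype)
    n = mask.sum()
    if n <= 1:
        raise ValueError("need at least two selected elements for an unbiased variance")
    mean = (values * mask).sum() / n
    # Unbiased (Bessel-corrected) variance over the selected elements.
    var = (((values - mean) ** 2) * mask).sum() / (n - 1)
    # Every element is rescaled; the mask only decides which ones feed the statistics.
    whitened = (values - mean) * torch.rsqrt(var + eps)
    if not shift_mean:
        # Re-add the original mean so only the scale changes.
        whitened = whitened + mean
    return whitened
```

For example, with values [0, 2, 4, 100] and only the first three elements selected, the statistics are mean 2 and variance 4, so the result is [-1, 0, 1, 49]; with shift_mean=False it becomes [1, 2, 3, 51]. The outlier is still rescaled but does not distort the statistics.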

trinity.algorithm.utils.prefix_metrics(src_metrics: dict, prefix: str, dst_metrics: dict = None) -> dict