trinity.algorithm.utils module

Common utils for algorithm module.

Modified from volcengine/verl

trinity.algorithm.utils.masked_loss(values, mask, loss_agg_mode='token-mean', normalizer=None)

Compute loss from values and mask with various aggregation modes. Modified from: volcengine/verl

Parameters:
  • values (torch.Tensor) – Arbitrary shape tensor of values to aggregate.

  • mask (torch.BoolTensor or torch.FloatTensor) – Same shape as values; 1/True = include, 0/False = ignore.

  • loss_agg_mode (str) – One of the following:
      - “token-mean”: mean over all unmasked elements.
      - “seq-mean-token-sum”: average over sequences, where each sequence’s loss is the sum of its unmasked values.
      - “seq-mean-token-mean”: average over sequences, where each sequence’s loss is the mean of its unmasked values.
      - “seq-mean-token-sum-norm”: total sum of unmasked values divided by a fixed normalizer (e.g., sequence length).

  • normalizer (float or None) – Only used in ‘seq-mean-token-sum-norm’. If None, uses mask.shape[-1].

Returns:

Scalar loss value.

Return type:

torch.Tensor
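The four aggregation modes differ only in how per-sequence sums and counts are combined. The following is a minimal plain-Python sketch of that arithmetic on nested lists, not the library implementation: the helper name `aggregate` is hypothetical, the layout is assumed to be `[sequence][token]`, and every sequence is assumed to contain at least one unmasked token.

```python
# Illustrative only: plain-Python arithmetic for the four aggregation modes.
# `aggregate` is a hypothetical helper, not part of trinity.algorithm.utils.
values = [[1.0, 2.0, 3.0, 4.0],
          [5.0, 6.0, 7.0, 8.0]]
mask = [[1, 1, 0, 0],
        [1, 1, 1, 0]]


def aggregate(values, mask, loss_agg_mode="token-mean", normalizer=None):
    # Sum and count of the unmasked entries in each sequence (row).
    seq_sums = [sum(v * m for v, m in zip(row, mrow))
                for row, mrow in zip(values, mask)]
    seq_counts = [sum(mrow) for mrow in mask]

    if loss_agg_mode == "token-mean":
        # Mean over all unmasked elements, regardless of sequence.
        return sum(seq_sums) / sum(seq_counts)
    if loss_agg_mode == "seq-mean-token-sum":
        # Each sequence contributes its sum; average over sequences.
        return sum(seq_sums) / len(seq_sums)
    if loss_agg_mode == "seq-mean-token-mean":
        # Each sequence contributes its mean; average over sequences.
        # Assumes every sequence has at least one unmasked token.
        return sum(s / c for s, c in zip(seq_sums, seq_counts)) / len(seq_sums)
    if loss_agg_mode == "seq-mean-token-sum-norm":
        # Total sum divided by a fixed normalizer; defaults to the last
        # dimension's length, mirroring the documented mask.shape[-1].
        norm = len(mask[0]) if normalizer is None else normalizer
        return sum(seq_sums) / norm
    raise ValueError(f"Unknown loss_agg_mode: {loss_agg_mode}")


for mode in ("token-mean", "seq-mean-token-sum",
             "seq-mean-token-mean", "seq-mean-token-sum-norm"):
    print(mode, aggregate(values, mask, mode))
```

On this input, “token-mean” weights every token equally (21 / 5), while “seq-mean-token-mean” weights every sequence equally ((1.5 + 6) / 2), so shorter sequences have more influence per token in the latter.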

trinity.algorithm.utils.masked_sum(values, mask, axis=None)

Compute the sum of a tensor over masked values.

trinity.algorithm.utils.masked_mean(values, mask, axis=None)

Compute the mean of a tensor over masked values.

trinity.algorithm.utils.masked_var(values, mask, unbiased=True)

Compute the variance of a tensor over masked values.
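As a rough illustration of the masked statistics above, the sketch below computes a masked sum, mean, and variance over flat Python lists. It is not the library code: the `*_1d` helper names are hypothetical, and it assumes that `unbiased=True` means applying Bessel's correction (scaling by n / (n - 1), where n is the number of unmasked elements), which is the usual meaning of that flag.

```python
# Illustrative only: masked statistics on flat lists.
# The *_1d helpers are hypothetical, not part of trinity.algorithm.utils.
def masked_sum_1d(values, mask):
    # Masked-out entries (mask == 0) contribute nothing to the sum.
    return sum(v * m for v, m in zip(values, mask))


def masked_mean_1d(values, mask):
    # Divide by the number of unmasked entries, not by len(values).
    return masked_sum_1d(values, mask) / sum(mask)


def masked_var_1d(values, mask, unbiased=True):
    mean = masked_mean_1d(values, mask)
    squared_dev = [(v - mean) ** 2 for v in values]
    var = masked_mean_1d(squared_dev, mask)
    if unbiased:
        # Assumed behaviour: Bessel's correction; requires at least
        # two unmasked elements.
        n = sum(mask)
        var *= n / (n - 1)
    return var


values = [1.0, 2.0, 3.0, 100.0]
mask = [1, 1, 1, 0]  # the outlier 100.0 is excluded from all statistics

print(masked_sum_1d(values, mask))
print(masked_mean_1d(values, mask))
print(masked_var_1d(values, mask, unbiased=True))
print(masked_var_1d(values, mask, unbiased=False))
```

The masked-out value still occupies a slot in the input, but it does not affect any of the results.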

trinity.algorithm.utils.masked_whiten(values, mask, shift_mean=True)

Whiten values by normalizing with mean and variance computed over mask.

Parameters:
  • values (torch.Tensor) – Input tensor.

  • mask (torch.Tensor) – Boolean tensor of same shape, selects elements for stats.

  • shift_mean (bool) – If True (default), output is zero-mean; if False, the original mean is re-added after scaling.

Returns:

Whitened tensor of same shape as values.

Return type:

torch.Tensor
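The effect of `shift_mean` can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library implementation: the helper name `whiten_1d` is hypothetical, the statistics are assumed to use the unbiased masked variance, and the small `eps` added for numerical stability is an assumed value.

```python
import math


# Illustrative only: `whiten_1d` is a hypothetical helper, and both the
# unbiased variance and eps = 1e-8 are assumptions made for this sketch.
def whiten_1d(values, mask, shift_mean=True, eps=1e-8):
    n = sum(mask)
    # Mean and variance are computed over unmasked elements only.
    mean = sum(v * m for v, m in zip(values, mask)) / n
    var = sum(m * (v - mean) ** 2 for v, m in zip(values, mask)) / (n - 1)
    scale = 1.0 / math.sqrt(var + eps)
    # Every element is transformed, including masked-out ones.
    whitened = [(v - mean) * scale for v in values]
    if not shift_mean:
        # Re-add the original masked mean after scaling.
        whitened = [w + mean for w in whitened]
    return whitened


values = [1.0, 2.0, 3.0, 100.0]
mask = [1, 1, 1, 0]

zero_mean = whiten_1d(values, mask, shift_mean=True)
kept_mean = whiten_1d(values, mask, shift_mean=False)
print(zero_mean)
print(kept_mean)
```

With `shift_mean=True` the unmasked outputs are centered at zero; with `shift_mean=False` they are rescaled to unit variance but stay centered at the original masked mean.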

trinity.algorithm.utils.prefix_metrics(src_metrics: dict, prefix: str, dst_metrics: dict | None = None) → dict