trinity.algorithm.utils module
Common utils for algorithm module.
Modified from volcengine/verl
- trinity.algorithm.utils.masked_loss(values, mask, loss_agg_mode='token-mean', normalizer=None)
Aggregate a tensor of per-element loss values into a scalar loss, using a mask and one of several aggregation modes. Modified from volcengine/verl.
- Parameters:
values (torch.Tensor) – Arbitrary shape tensor of values to aggregate.
mask (torch.BoolTensor or torch.FloatTensor) – Same shape as values; 1/True = include, 0/False = ignore.
loss_agg_mode (str) – One of the following:
  - “token-mean”: mean over all unmasked elements.
  - “seq-mean-token-sum”: average over sequences, where each sequence’s loss is the sum of its unmasked values.
  - “seq-mean-token-mean”: average over sequences, where each sequence’s loss is the mean of its unmasked values.
  - “seq-mean-token-sum-norm”: total sum of unmasked values divided by a fixed normalizer (e.g., the sequence length).
normalizer (float or None) – Used only in ‘seq-mean-token-sum-norm’ mode. If None, defaults to mask.shape[-1], the size of the mask's last dimension.
- Returns:
Scalar loss value.
- Return type:
torch.Tensor
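The four aggregation modes can give quite different results on the same input. The following is a dependency-free sketch of the arithmetic described above, using nested lists in place of tensors (rows are sequences, columns are tokens). The helper name `aggregate` is hypothetical and written for illustration only; it is not the library's implementation.

```python
def aggregate(values, mask, mode="token-mean", normalizer=None):
    """Illustrative restatement of the documented modes for 2-D inputs.

    Hypothetical helper, not the library code.
    """
    # Per-sequence sum of unmasked values and count of unmasked tokens.
    seq_sums = [sum(v * m for v, m in zip(vr, mr)) for vr, mr in zip(values, mask)]
    seq_counts = [sum(mr) for mr in mask]

    if mode == "token-mean":
        # Every unmasked token has equal weight, regardless of its sequence.
        return sum(seq_sums) / sum(seq_counts)
    if mode == "seq-mean-token-sum":
        # Each sequence contributes its summed loss; then average over sequences.
        return sum(seq_sums) / len(seq_sums)
    if mode == "seq-mean-token-mean":
        # Each sequence contributes its mean loss; then average over sequences.
        per_seq = [s / c for s, c in zip(seq_sums, seq_counts)]
        return sum(per_seq) / len(per_seq)
    if mode == "seq-mean-token-sum-norm":
        # Total sum divided by a fixed normalizer (default: last-dimension size).
        norm = len(mask[0]) if normalizer is None else normalizer
        return sum(seq_sums) / norm
    raise ValueError(f"unknown mode: {mode}")


values = [[1.0, 2.0, 3.0, 4.0],
          [2.0, 4.0, 6.0, 8.0]]
mask = [[1, 1, 0, 0],   # short sequence: 2 valid tokens
        [1, 1, 1, 1]]   # long sequence: 4 valid tokens

modes = ("token-mean", "seq-mean-token-sum",
         "seq-mean-token-mean", "seq-mean-token-sum-norm")
results = {m: aggregate(values, mask, m) for m in modes}
# token-mean              -> 23 / 6 (about 3.83)
# seq-mean-token-sum      -> 11.5
# seq-mean-token-mean     -> 3.25
# seq-mean-token-sum-norm -> 5.75
```

On this input, “token-mean” is pulled toward the longer sequence because it contributes more tokens, while “seq-mean-token-mean” weights both sequences equally.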
- trinity.algorithm.utils.masked_sum(values, mask, axis=None)
Compute the sum of a tensor over its unmasked elements.
- trinity.algorithm.utils.masked_mean(values, mask, axis=None)
Compute the mean of a tensor over its unmasked elements.
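As a rough illustration of how a masked sum and a masked mean relate, here is a minimal pure-Python sketch over flat lists. The names `masked_sum_1d` and `masked_mean_1d` are hypothetical stand-ins; the `axis` argument of the real functions is not modeled.

```python
def masked_sum_1d(values, mask):
    # Masked-out positions (mask == 0) contribute nothing to the total.
    return sum(v * m for v, m in zip(values, mask))


def masked_mean_1d(values, mask):
    # Divide by the number of included elements, not by len(values).
    return masked_sum_1d(values, mask) / sum(mask)


values = [1.0, 2.0, 3.0, 100.0]
mask = [1, 1, 1, 0]          # the outlier at the end is ignored

total = masked_sum_1d(values, mask)   # 6.0
mean = masked_mean_1d(values, mask)   # 2.0
```

Note that this sketch divides by zero when the mask excludes every element; how the library handles that case is not stated here.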
- trinity.algorithm.utils.masked_var(values, mask, unbiased=True)
Compute the variance of a tensor over its unmasked elements.
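The sketch below shows one way a masked variance can be computed on a flat list. It assumes that `unbiased=True` means Bessel's correction (dividing by n - 1 instead of n), which is the usual meaning of the flag but is not spelled out above. The name `masked_var_1d` is hypothetical.

```python
def masked_var_1d(values, mask, unbiased=True):
    n = sum(mask)  # number of included elements
    mean = sum(v * m for v, m in zip(values, mask)) / n
    # Population variance over the included elements only.
    var = sum(m * (v - mean) ** 2 for v, m in zip(values, mask)) / n
    if unbiased:
        # Assumption: "unbiased" applies Bessel's correction, n / (n - 1).
        if n < 2:
            raise ValueError("need at least two unmasked elements")
        var *= n / (n - 1)
    return var


values = [2.0, 4.0, 6.0, 50.0]
mask = [1, 1, 1, 0]

unbiased_var = masked_var_1d(values, mask)                 # 8 / 2 = 4.0
biased_var = masked_var_1d(values, mask, unbiased=False)   # 8 / 3
```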
- trinity.algorithm.utils.masked_whiten(values, mask, shift_mean=True)
Whiten values by normalizing with mean and variance computed over mask.
- Parameters:
values (torch.Tensor) – Input tensor.
mask (torch.Tensor) – Boolean tensor of the same shape as values; selects the elements used to compute the mean and variance.
shift_mean (bool) – If True (default), output is zero-mean; if False, the original mean is re-added after scaling.
- Returns:
Whitened tensor of same shape as values.
- Return type:
torch.Tensor
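The effect of `shift_mean` is easiest to see on a small example. The following pure-Python sketch follows the description above: statistics come from the unmasked elements only, while every element is transformed. The name `masked_whiten_1d`, the `eps` stabilizer, and the use of the unbiased variance are assumptions made for illustration, not details confirmed by this page.

```python
import math


def masked_whiten_1d(values, mask, shift_mean=True, eps=1e-8):
    n = sum(mask)
    mean = sum(v * m for v, m in zip(values, mask)) / n
    # Assumption: unbiased variance (Bessel's correction) over unmasked elements.
    var = sum(m * (v - mean) ** 2 for v, m in zip(values, mask)) / (n - 1)
    # Assumption: a small eps guards against division by zero.
    scale = 1.0 / math.sqrt(var + eps)
    whitened = [(v - mean) * scale for v in values]
    if not shift_mean:
        # Re-add the original mean after scaling, as described for shift_mean=False.
        whitened = [w + mean for w in whitened]
    return whitened


values = [2.0, 4.0, 6.0, 50.0]
mask = [1, 1, 1, 0]   # stats use the first three values: mean 4, std 2

centered = masked_whiten_1d(values, mask)                       # approx [-1, 0, 1, 23]
uncentered = masked_whiten_1d(values, mask, shift_mean=False)   # approx [3, 4, 5, 27]
```

The masked-out value 50.0 does not influence the mean or variance, but it is still rescaled with them in this sketch.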