trinity.algorithm.utils module
Common utils for algorithm module.
Modified from volcengine/verl
- trinity.algorithm.utils.aggregate_loss(values, mask, loss_agg_mode='token-mean', normalizer=None)
Compute loss from values and mask with various aggregation modes. Modified from: volcengine/verl
- Parameters:
values (torch.Tensor) -- Arbitrary shape tensor of values to aggregate.
mask (torch.BoolTensor or torch.FloatTensor) -- Same shape as values; 1/True = include, 0/False = ignore.
loss_agg_mode (str) -- One of the following:
  - "token-mean": mean over all unmasked elements.
  - "seq-mean-token-sum": average over sequences, where each sequence's loss is the sum of its unmasked values.
  - "seq-mean-token-mean": average over sequences, where each sequence's loss is the mean of its unmasked values.
  - "seq-mean-token-sum-norm": total sum of unmasked values divided by a fixed normalizer (e.g., sequence length).
normalizer (float or None) -- Only used in 'seq-mean-token-sum-norm'. If None, uses mask.shape[-1].
- Returns:
Scalar loss value.
- Return type:
torch.Tensor
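The four aggregation modes are easiest to compare on a small batch. The following is a minimal sketch that re-implements the documented behavior in plain PyTorch; the name `aggregate_loss_sketch` and the `clamp(min=1)` guard for fully masked sequences are assumptions made for illustration, not the library's actual code.

```python
import torch


def aggregate_loss_sketch(values, mask, loss_agg_mode="token-mean", normalizer=None):
    """Illustrative re-implementation of the documented modes (not the library code)."""
    mask = mask.to(values.dtype)  # accept bool or float masks
    if loss_agg_mode == "token-mean":
        # Mean over every unmasked element in the batch.
        return (values * mask).sum() / mask.sum()

    seq_sums = (values * mask).sum(dim=-1)  # per-sequence sum of unmasked values
    if loss_agg_mode == "seq-mean-token-sum":
        return seq_sums.mean()
    if loss_agg_mode == "seq-mean-token-mean":
        # Assumption: avoid 0/0 for fully masked sequences by clamping the count.
        return (seq_sums / mask.sum(dim=-1).clamp(min=1)).mean()
    if loss_agg_mode == "seq-mean-token-sum-norm":
        if normalizer is None:
            normalizer = mask.shape[-1]  # documented default
        return seq_sums.sum() / normalizer
    raise ValueError(f"Unknown loss_agg_mode: {loss_agg_mode}")


values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[True, True, False], [True, False, False]])

modes = (
    "token-mean",               # (1 + 2 + 4) / 3
    "seq-mean-token-sum",       # mean of [3, 4] = 3.5
    "seq-mean-token-mean",      # mean of [1.5, 4] = 2.75
    "seq-mean-token-sum-norm",  # 7 / mask.shape[-1] = 7 / 3
)
results = {m: aggregate_loss_sketch(values, mask, m).item() for m in modes}
for mode, loss in results.items():
    print(f"{mode}: {loss:.4f}")
```

Note that "token-mean" weights every token equally, so long sequences dominate, whereas the "seq-mean-*" modes weight every sequence equally.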
- trinity.algorithm.utils.masked_sum(values, mask, axis=None)
Compute the sum of the tensor elements selected by mask.
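As a rough illustration of what a masked sum computes, the sketch below zeroes out ignored positions and then sums. The helper name `masked_sum_sketch` is hypothetical and the body is an assumption about the behavior, not the library's implementation.

```python
import torch


def masked_sum_sketch(values, mask, axis=None):
    """Sum of `values` where `mask` is set (illustrative, not the library code)."""
    # Replace ignored positions with zero rather than multiplying,
    # so NaN or inf in ignored positions cannot leak into the result.
    kept = torch.where(mask.bool(), values, torch.zeros_like(values))
    return kept.sum() if axis is None else kept.sum(dim=axis)


values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])

total = masked_sum_sketch(values, mask)             # 1 + 2 + 4 = 7
per_row = masked_sum_sketch(values, mask, axis=-1)  # [3, 4]
print(total.item(), per_row.tolist())
```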
- trinity.algorithm.utils.masked_mean(values, mask, axis=None)
Compute the mean of the tensor elements selected by mask.
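A masked mean divides the sum of the included elements by the number of included elements, not by the tensor size. A minimal sketch, assuming this definition (the name `masked_mean_sketch` is hypothetical; the library may differ in details such as adding a small epsilon to the denominator):

```python
import torch


def masked_mean_sketch(values, mask, axis=None):
    """Mean of `values` over positions where `mask` is set (illustrative only)."""
    mask = mask.to(values.dtype)
    if axis is None:
        return (values * mask).sum() / mask.sum()
    # Per-slice mean: a fully masked slice yields 0/0 = nan in this sketch.
    return (values * mask).sum(dim=axis) / mask.sum(dim=axis)


values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[True, True, False], [True, False, False]])

overall = masked_mean_sketch(values, mask)          # (1 + 2 + 4) / 3
per_row = masked_mean_sketch(values, mask, axis=-1) # [1.5, 4.0]
print(overall.item(), per_row.tolist())
```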
- trinity.algorithm.utils.masked_var(values, mask, unbiased=True)
Compute variance of tensor with masked values.
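The effect of the `unbiased` flag can be sketched as follows. This assumes `unbiased=True` means Bessel's correction (dividing by n - 1 instead of n, where n is the number of unmasked elements); `masked_var_sketch` is a hypothetical name and not the library's code.

```python
import torch


def masked_var_sketch(values, mask, unbiased=True):
    """Variance over unmasked elements (illustrative, not the library code)."""
    mask = mask.to(values.dtype)
    n = mask.sum()
    mean = (values * mask).sum() / n
    var = (((values - mean) ** 2) * mask).sum() / n  # biased (population) variance
    if unbiased:
        # Bessel's correction; requires at least two unmasked elements.
        var = var * n / (n - 1)
    return var


values = torch.tensor([1.0, 2.0, 3.0, 100.0])
mask = torch.tensor([True, True, True, False])  # the outlier 100.0 is ignored

unbiased_var = masked_var_sketch(values, mask)                # 1.0
biased_var = masked_var_sketch(values, mask, unbiased=False)  # 2 / 3
print(unbiased_var.item(), biased_var.item())
```

With a boolean mask, the unbiased result matches `values[mask].var()`, which also applies Bessel's correction by default.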
- trinity.algorithm.utils.masked_whiten(values, mask, shift_mean=True)
Whiten values by normalizing with mean and variance computed over mask.
- Parameters:
values (torch.Tensor) -- Input tensor.
mask (torch.Tensor) -- Boolean tensor of same shape, selects elements for stats.
shift_mean (bool) -- If True (default), output is zero-mean; if False, the original mean is re-added after scaling.
- Returns:
Whitened tensor of same shape as values.
- Return type:
torch.Tensor
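The whitening step and the role of `shift_mean` can be sketched as below. This is a minimal illustration under stated assumptions: statistics use the unbiased variance of the unmasked elements, and a small `eps` is added for numerical stability. Both choices, and the name `masked_whiten_sketch`, are assumptions rather than the library's confirmed implementation.

```python
import torch


def masked_whiten_sketch(values, mask, shift_mean=True, eps=1e-8):
    """Whiten `values` using mean/variance of unmasked elements (illustrative only)."""
    m = mask.to(values.dtype)
    n = m.sum()
    mean = (values * m).sum() / n
    # Assumption: unbiased variance over the unmasked elements.
    var = (((values - mean) ** 2) * m).sum() / (n - 1)
    whitened = (values - mean) * torch.rsqrt(var + eps)
    if not shift_mean:
        # Keep unit variance but restore the original mean.
        whitened = whitened + mean
    return whitened


values = torch.tensor([1.0, 2.0, 3.0, 100.0])
mask = torch.tensor([True, True, True, False])

centered = masked_whiten_sketch(values, mask)                     # unmasked part ~ [-1, 0, 1]
kept_mean = masked_whiten_sketch(values, mask, shift_mean=False)  # unmasked part ~ [1, 2, 3]
print(centered[mask].tolist(), kept_mean[mask].tolist())
```

In this sketch the statistics come only from the unmasked elements, but the transform is applied to every element, so the output has the same shape as `values`.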