trinity.algorithm.utils module
Common utilities for the algorithm module.
Modified from https://github.com/volcengine/verl/blob/main/verl/utils/torch_functional.py
- trinity.algorithm.utils.masked_sum(values, mask, axis=None)
Compute the sum of values over the elements selected by mask, optionally along axis.
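A minimal sketch of the masked-sum semantics in plain PyTorch. It assumes mask is a 0/1 (or boolean) tensor with the same shape as values, and it is an illustrative reimplementation, not the module's source. It uses torch.where instead of values * mask so that a NaN at a masked-out position cannot leak into the result (NaN * 0 is still NaN).

```python
import torch


def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis=None) -> torch.Tensor:
    """Sum of `values` over positions where `mask` is nonzero (illustrative sketch)."""
    # Replace masked-out positions with 0 so NaN/inf there cannot propagate.
    valid = torch.where(mask.bool(), values, torch.zeros_like(values))
    if axis is None:
        return valid.sum()  # reduce over all elements
    return valid.sum(dim=axis)  # reduce along the requested dimension


values = torch.tensor([[1.0, 2.0, 3.0], [4.0, float("nan"), 6.0]])
mask = torch.tensor([[1, 1, 0], [1, 0, 1]])  # the NaN is masked out
print(masked_sum(values, mask).item())            # 13.0
print(masked_sum(values, mask, axis=1).tolist())  # [3.0, 10.0]
```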
- trinity.algorithm.utils.masked_mean(values, mask, axis=None)
Compute the mean of values over the elements selected by mask, optionally along axis.
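The masked mean is the masked sum divided by the number of selected elements. The sketch below illustrates that in plain PyTorch; the small eps added to the denominator is an assumption of this sketch (a common guard against an all-zero mask), not a documented parameter of trinity.algorithm.utils.masked_mean.

```python
import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis=None,
                eps: float = 1e-8) -> torch.Tensor:
    """Mean of `values` over positions where `mask` is nonzero (illustrative sketch)."""
    mask = mask.to(values.dtype)
    # Masked-out positions contribute nothing to the numerator.
    valid = torch.where(mask.bool(), values, torch.zeros_like(values))
    if axis is None:
        return valid.sum() / (mask.sum() + eps)
    # Hypothetical eps keeps an all-zero mask slice from dividing by zero.
    return valid.sum(dim=axis) / (mask.sum(dim=axis) + eps)


values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[1, 1, 0], [0, 0, 1]])
print(masked_mean(values, mask).item())            # about 3.0, i.e. (1 + 2 + 6) / 3
print(masked_mean(values, mask, axis=1).tolist())  # about [1.5, 6.0]
```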
- trinity.algorithm.utils.masked_var(values, mask, unbiased=True)
Compute the variance of values over the elements selected by mask. If unbiased is True (default), Bessel's correction is applied.
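A sketch of the masked variance: the mean of squared deviations from the masked mean over the selected elements, multiplied by n / (n - 1) (Bessel's correction) when unbiased is True, where n is the number of selected elements. Raising a ValueError when too few elements are selected is this sketch's own choice; how the library handles that edge case is not stated on this page.

```python
import torch


def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
    """Variance of `values` over positions where `mask` is nonzero (illustrative sketch)."""
    keep = mask.bool()
    n = keep.sum().to(values.dtype)  # number of selected elements
    if n < (2 if unbiased else 1):
        raise ValueError("mask selects too few elements for this variance estimate")
    zeros = torch.zeros_like(values)
    mean = torch.where(keep, values, zeros).sum() / n
    # Squared deviations from the masked mean, counted only at selected positions.
    sq_dev = torch.where(keep, (values - mean) ** 2, zeros)
    var = sq_dev.sum() / n  # biased (population) estimate
    if unbiased:
        var = var * n / (n - 1)  # Bessel's correction
    return var


values = torch.tensor([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0, 100.0])
mask = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 0])  # the outlier 100.0 is excluded
print(masked_var(values, mask, unbiased=False).item())  # 4.0
print(masked_var(values, mask).item())                  # about 4.571 (= 32 / 7)
```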
- trinity.algorithm.utils.masked_whiten(values, mask, shift_mean=True)
Whiten values by normalizing with mean and variance computed over mask.
- Parameters:
values (torch.Tensor) – Input tensor.
mask (torch.Tensor) – Boolean tensor with the same shape as values; selects the elements used to compute the statistics.
shift_mean (bool) – If True (default), the output has zero mean; if False, the original mean is re-added after scaling.
- Returns:
Whitened tensor with the same shape as values.
- Return type:
torch.Tensor
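A sketch of masked whitening, assuming the statistics are the masked mean and the unbiased masked variance (matching masked_var's default of unbiased=True). For brevity it gathers the selected elements with boolean indexing rather than calling the functions documented above, and eps is a hypothetical stabilizer not listed among the documented parameters. In this sketch every position is transformed, but only masked-in positions influence the statistics.

```python
import torch


def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True,
                  eps: float = 1e-8) -> torch.Tensor:
    """Whiten `values` using statistics from masked-in positions only (illustrative sketch)."""
    selected = values[mask.bool()]  # 1-D tensor of the selected elements
    mean = selected.mean()
    var = selected.var()  # torch's default is the unbiased (Bessel-corrected) estimate
    # All positions are rescaled; masked-out ones simply did not affect mean/var.
    whitened = (values - mean) * torch.rsqrt(var + eps)
    if not shift_mean:
        whitened = whitened + mean  # re-add the original mean after scaling
    return whitened


values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 100.0]])
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])  # 100.0 is excluded from the statistics
out = masked_whiten(values, mask)
print(out[mask.bool()])  # masked-in outputs: zero mean, unit (unbiased) variance
```

With shift_mean=False, the masked-in outputs keep the original masked mean (3.0 here) while still having unit variance.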