trinity.algorithm.utils module

Common utilities for the algorithm module.

Modified from https://github.com/volcengine/verl/blob/main/verl/utils/torch_functional.py

trinity.algorithm.utils.masked_sum(values, mask, axis=None)

Compute sum of tensor with masked values.
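A minimal sketch of a masked sum, assuming it behaves like the verl helper this module is modified from: positions where the mask is zero/False are zeroed by multiplication before summing. The explicit `axis=None` branch (sum over every element) is a choice made for this sketch; Trinity's actual code may differ.

```python
import torch


def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis=None) -> torch.Tensor:
    # Zero out unselected positions (bool mask is promoted to values' dtype).
    masked = values * mask
    # axis=None: reduce over all elements; otherwise reduce along the given dim.
    if axis is None:
        return masked.sum()
    return masked.sum(dim=axis)


# Usage: only the True positions contribute to the sum.
values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[True, True, False], [True, False, False]])
total = masked_sum(values, mask)
per_row = masked_sum(values, mask, axis=1)
```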

trinity.algorithm.utils.masked_mean(values, mask, axis=None)

Compute mean of tensor with masked values.
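A minimal sketch of a masked mean, assuming the verl-style definition: the masked sum divided by the number of selected elements. The small epsilon in the denominator (guarding against an all-False mask) is an assumption of this sketch, not a documented detail of Trinity.

```python
import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis=None) -> torch.Tensor:
    # Cast the (possibly boolean) mask to the values' dtype so counts are floats.
    mask = mask.to(values.dtype)
    eps = 1e-8  # assumed guard against division by zero for an empty mask
    if axis is None:
        return (values * mask).sum() / (mask.sum() + eps)
    return (values * mask).sum(dim=axis) / (mask.sum(dim=axis) + eps)


# Usage: overall mean of the selected entries, and per-row means.
values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[True, True, False], [True, False, False]])
overall = masked_mean(values, mask)
per_row = masked_mean(values, mask, axis=1)
```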

trinity.algorithm.utils.masked_var(values, mask, unbiased=True)

Compute variance of tensor with masked values.
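A minimal sketch of masked variance, assuming it follows the verl pattern: take the masked mean of squared deviations from the masked mean, then, when `unbiased=True`, apply Bessel's correction n / (n - 1), where n is the number of selected elements. Raising an error for fewer than two selected elements is this sketch's choice.

```python
import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Global masked mean (epsilon is an assumed guard for an empty mask).
    mask = mask.to(values.dtype)
    return (values * mask).sum() / (mask.sum() + 1e-8)


def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
    mean = masked_mean(values, mask)
    # Biased (population) variance over the selected elements only.
    variance = masked_mean((values - mean) ** 2, mask)
    if unbiased:
        n = mask.to(values.dtype).sum()
        if n < 2:
            raise ValueError("unbiased variance needs at least two selected elements")
        # Bessel's correction: rescale from dividing by n to dividing by n - 1.
        variance = variance * n / (n - 1)
    return variance


# Usage: the outlier 100.0 is masked out and does not affect the result.
values = torch.tensor([1.0, 2.0, 3.0, 4.0, 100.0])
mask = torch.tensor([True, True, True, True, False])
var_unbiased = masked_var(values, mask)
var_biased = masked_var(values, mask, unbiased=False)
```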

trinity.algorithm.utils.masked_whiten(values, mask, shift_mean=True)

Whiten values by normalizing with mean and variance computed over mask.

Parameters:
  • values (torch.Tensor) – Input tensor.

  • mask (torch.Tensor) – Boolean tensor of same shape, selects elements for stats.

  • shift_mean (bool) – If True (default), output is zero-mean; if False, the original mean is re-added after scaling.

Returns:

Whitened tensor of same shape as values.

Return type:

torch.Tensor
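A minimal sketch of masked whitening under the same assumptions as verl's helper: the mean and (unbiased) variance are computed over the mask, every element is shifted by the mean and scaled by the inverse standard deviation, and with `shift_mean=False` the mean is added back after scaling. The epsilon inside `torch.rsqrt` is assumed. In this sketch, masked-out positions are still transformed rather than zeroed; only the statistics are restricted to the mask.

```python
import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    mask = mask.to(values.dtype)
    return (values * mask).sum() / (mask.sum() + 1e-8)


def masked_var(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Unbiased variance over the selected elements (Bessel's correction).
    mean = masked_mean(values, mask)
    n = mask.to(values.dtype).sum()
    return masked_mean((values - mean) ** 2, mask) * n / (n - 1)


def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    # Normalize to zero mean and unit variance (epsilon avoids 1/sqrt(0)).
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        # Keep the scaling but restore the original mean.
        whitened = whitened + mean
    return whitened


# Usage: statistics come from the first four entries only.
values = torch.tensor([1.0, 2.0, 3.0, 4.0, 100.0])
mask = torch.tensor([True, True, True, True, False])
out = masked_whiten(values, mask)
out_keep_mean = masked_whiten(values, mask, shift_mean=False)
```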

trinity.algorithm.utils.prefix_metrics(src_metrics: dict, prefix: str, dst_metrics: dict | None = None) → dict