"""Common utils for algorithm module.
Modified from https://github.com/volcengine/verl/blob/main/verl/utils/torch_functional.py
"""
import torch
def masked_sum(values, mask, axis=None):
    """Compute the sum of tensor elements selected by the mask."""
    return (values * mask).sum(axis=axis)
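
# Illustrative usage (not from the original source): with a 0/1 float mask,
# only the masked-in positions contribute to the sum.
#   >>> values = torch.tensor([1.0, 2.0, 3.0, 4.0])
#   >>> mask = torch.tensor([1.0, 1.0, 1.0, 0.0])
#   >>> masked_sum(values, mask)
#   tensor(6.)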
def masked_mean(values, mask, axis=None):
    """Compute the mean of tensor elements selected by the mask."""
    return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8)
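
# Illustrative usage (not from the original source): the 1e-8 term only guards
# against an all-zero mask, so the result is approximately the mean of the
# masked-in elements.
#   >>> masked_mean(torch.tensor([1.0, 2.0, 3.0, 4.0]),
#   ...             torch.tensor([1.0, 1.0, 1.0, 0.0]))
#   tensor(2.0000)  # approximately; 6 / (3 + 1e-8)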
def masked_var(values, mask, unbiased=True):
    """Compute the variance of tensor elements selected by the mask."""
    mean = masked_mean(values, mask)
    centered_values = values - mean
    variance = masked_mean(centered_values**2, mask)
    if unbiased:
        mask_sum = mask.sum()
        if mask_sum == 0:
            raise ValueError("At least one element in the mask has to be 1.")
        # note that if mask_sum == 1, then there is a division by zero issue
        # to avoid it you just need to use a larger minibatch_size
        if mask_sum == 1:
            raise ValueError("The sum of the mask is one, which can cause a division by zero.")
        bessel_correction = mask_sum / (mask_sum - 1)
        variance = variance * bessel_correction
    return variance
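
# Illustrative usage (not from the original source): with three masked-in
# elements [1, 2, 3], the biased variance is 2/3 and the Bessel-corrected
# (unbiased) variance is 2/3 * 3/2 = 1.
#   >>> masked_var(torch.tensor([1.0, 2.0, 3.0, 4.0]),
#   ...            torch.tensor([1.0, 1.0, 1.0, 0.0]))
#   tensor(1.0000)  # approximately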
def masked_whiten(values, mask, shift_mean=True):
    """
    Whiten `values` by normalizing with mean and variance computed over `mask`.

    Args:
        values (torch.Tensor): Input tensor.
        mask (torch.Tensor): Boolean tensor of same shape, selects elements for stats.
        shift_mean (bool): If True (default), output is zero-mean;
                           if False, the original mean is re-added after scaling.

    Returns:
        torch.Tensor: Whitened tensor of same shape as `values`.
    """
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened
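
# Illustrative usage (not from the original source): statistics come from the
# masked-in elements only, but the normalization is applied to every element,
# including the masked-out ones.
#   >>> masked_whiten(torch.tensor([1.0, 2.0, 3.0, 4.0]),
#   ...               torch.tensor([1.0, 1.0, 1.0, 0.0]))
#   tensor([-1.,  0.,  1.,  2.])  # approximately; (values - 2) / sqrt(1 + 1e-8)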
def prefix_metrics(src_metrics: dict, prefix: str, dst_metrics: dict = None) -> dict:
    """Copy `src_metrics` into `dst_metrics`, prefixing every key with `prefix/`."""
    if dst_metrics is None:
        dst_metrics = {}
    for k, v in src_metrics.items():
        dst_metrics[f"{prefix}/{k}"] = v
    return dst_metrics
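
# Illustrative usage (not from the original source): namespacing raw metrics
# before logging, optionally merging into an existing dict.
#   >>> prefix_metrics({"loss": 0.5, "acc": 0.9}, "train")
#   {'train/loss': 0.5, 'train/acc': 0.9}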