Source code for trinity.algorithm.utils

"""Common utils for algorithm module.

Modified from https://github.com/volcengine/verl/blob/main/verl/utils/torch_functional.py
"""

from typing import Optional

import torch


def masked_sum(values, mask, axis=None):
    """Compute the sum of ``values`` over the elements selected by ``mask``.

    Args:
        values (torch.Tensor): Input tensor.
        mask (torch.Tensor): Tensor multiplied element-wise with ``values``;
            entries equal to 0 remove the corresponding value from the sum.
        axis: Dimension(s) to reduce over, forwarded to ``Tensor.sum``.
            ``None`` (default) reduces over all elements.

    Returns:
        torch.Tensor: The sum of ``values * mask`` along ``axis``.
    """
    return (values * mask).sum(axis=axis)
def masked_mean(values, mask, axis=None):
    """Compute the mean of ``values`` over the elements selected by ``mask``.

    Args:
        values (torch.Tensor): Input tensor.
        mask (torch.Tensor): Tensor multiplied element-wise with ``values``;
            its sum along ``axis`` is used as the element count.
        axis: Dimension(s) to reduce over, forwarded to ``Tensor.sum``.
            ``None`` (default) reduces over all elements.

    Returns:
        torch.Tensor: Masked sum divided by the mask sum (plus ``1e-8``).
    """
    weighted_total = (values * mask).sum(axis=axis)
    # The small constant keeps the division finite when the mask sums to zero.
    selected_count = mask.sum(axis=axis) + 1e-8
    return weighted_total / selected_count
def masked_var(values, mask, unbiased=True):
    """Compute the variance of ``values`` over the elements selected by ``mask``.

    Args:
        values (torch.Tensor): Input tensor.
        mask (torch.Tensor): Mask selecting the elements that contribute.
        unbiased (bool): If True (default), apply Bessel's correction
            ``n / (n - 1)`` where ``n`` is the sum of the mask.

    Returns:
        torch.Tensor: The masked variance.

    Raises:
        ValueError: If ``unbiased`` is True and the mask sums to 0 or to 1.
    """
    mu = masked_mean(values, mask)
    squared_dev = (values - mu) ** 2
    var = masked_mean(squared_dev, mask)
    if not unbiased:
        return var

    n_selected = mask.sum()
    if n_selected == 0:
        raise ValueError("At least one element in the mask has to be 1.")
    # With a single selected element the correction n / (n - 1) divides by
    # zero; a larger minibatch_size avoids this situation.
    if n_selected == 1:
        raise ValueError("The sum of the mask is one, which can cause a division by zero.")
    return var * (n_selected / (n_selected - 1))
def masked_whiten(values, mask, shift_mean=True):
    """Normalize ``values`` using the mean and variance computed over ``mask``.

    Args:
        values (torch.Tensor): Input tensor.
        mask (torch.Tensor): Boolean tensor of same shape, selects the
            elements used for the statistics.
        shift_mean (bool): If True (default), the mean is subtracted from the
            output; if False, the original mean is added back after scaling.

    Returns:
        torch.Tensor: Whitened tensor of same shape as `values`.
    """
    mean = masked_mean(values, mask)
    var = masked_var(values, mask)
    # The epsilon guards rsqrt against a zero variance.
    inv_std = torch.rsqrt(var + 1e-8)
    whitened = (values - mean) * inv_std
    if shift_mean:
        return whitened
    whitened += mean
    return whitened
def prefix_metrics(
    src_metrics: dict, prefix: str, dst_metrics: Optional[dict] = None
) -> dict:
    """Copy metrics into ``dst_metrics`` under keys of the form ``prefix/key``.

    Args:
        src_metrics: Mapping of metric name to value; it is not modified.
        prefix: String prepended to every key, separated by ``/``.
        dst_metrics: Dictionary updated in place with the prefixed entries.
            When ``None`` (default), a new dictionary is created.

    Returns:
        The dictionary holding the prefixed metrics: ``dst_metrics`` itself
        when one was passed, otherwise the newly created dictionary.
    """
    if dst_metrics is None:
        dst_metrics = {}
    dst_metrics.update(
        (f"{prefix}/{key}", value) for key, value in src_metrics.items()
    )
    return dst_metrics