Source code for trinity.buffer.selector.difficulty_estimator

from typing import List

import numpy as np

from trinity.utils.log import get_logger


class BaseBetaPREstimator:
    """Per-task pass-rate estimator backed by one Beta(alpha, beta) pair per task.

    Every task starts from ``alpha = beta = 1``. Subclasses implement
    :meth:`update`, which turns observed pass rates into the quantities consumed
    by :meth:`_update`.
    """

    n: int
    m: int
    lamb: float
    rho: float
    alphas: np.ndarray
    betas: np.ndarray

    def __init__(self, n: int, m: int = 16, lamb: float = 0.2, rho: float = 0.2):
        """Initialise all tasks with ``alpha = beta = 1``.

        The recurrence applied by :meth:`_update` is::

            alpha_{t+1} = (1 - lamb) * alpha_t + lamb + (1 - rho) * s_bar + rho * m * p_tilde
            beta_{t+1}  = (1 - lamb) * beta_t  + lamb + (1 - rho) * f_bar + rho * m * (1 - p_tilde)

        Args:
            n (int): number of tasks.
            m (int): repeat times per task.
            lamb (float): discount factor of historical estimation.
            rho (float): weight of pseudo counts.
        """
        self.n = n
        self.m = m
        self.lamb = lamb
        self.rho = rho
        self.alphas = np.ones(n, dtype=float)
        self.betas = np.ones(n, dtype=float)
        self.logger = get_logger("BetaPREstimator")
        self.logger.debug(
            f"{self.n=}, {self.m=}, {self.lamb=}, {self.rho=}, {self.alphas=}, {self.betas=}"
        )

    def set(self, alphas, betas):
        """Overwrite the Beta parameters of all tasks.

        The inputs are coerced to float arrays so that plain sequences work
        with the array indexing used by :meth:`predict_pr` and
        :meth:`equivalent_count`. A float64 ndarray is stored as-is, without
        copying.

        Args:
            alphas: array-like of new alpha values.
            betas: array-like of new beta values.
        """
        self.alphas = np.asarray(alphas, dtype=float)
        self.betas = np.asarray(betas, dtype=float)

    def _update(self, s_bar, f_bar, p_tilde):
        """Apply one step of the discounted Beta update to every task.

        Args:
            s_bar: observed success counts, weighted by ``1 - rho``.
            f_bar: observed failure counts, weighted by ``1 - rho``.
            p_tilde: pseudo pass rates; ``p_tilde * m`` and ``(1 - p_tilde) * m``
                act as pseudo success/failure counts, weighted by ``rho``.
        """
        keep = 1 - self.lamb
        observed_weight = 1 - self.rho
        # The bare ``+ self.lamb`` term re-adds a lamb-sized share of the initial
        # value 1 to both parameters on every step.
        self.alphas = (
            keep * self.alphas
            + self.lamb
            + observed_weight * s_bar
            + self.rho * p_tilde * self.m
        )
        self.betas = (
            keep * self.betas
            + self.lamb
            + observed_weight * f_bar
            + self.rho * (1 - p_tilde) * self.m
        )

    def update(self, ref_indices: List[int], ref_pass_rates: List[float]):
        """Incorporate observed pass rates; must be implemented by subclasses.

        Args:
            ref_indices (List[int]): indices of the tasks that were evaluated.
            ref_pass_rates (List[float]): pass rates paired with ``ref_indices``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def predict_pr(self, rng=None, indices=None, do_sample=False):
        """Return a pass-rate estimate for the selected tasks.

        Args:
            rng: NumPy random generator used when ``do_sample`` is true; a fresh
                ``np.random.default_rng()`` is created if omitted.
            indices: task indices to evaluate; defaults to all ``n`` tasks.
            do_sample (bool): if false, return ``alpha / (alpha + beta)``;
                if true, return one Beta(alpha, beta) draw per task.

        Returns:
            np.ndarray: one value per selected task.
        """
        if indices is None:
            indices = np.arange(self.n)
        alphas = self.alphas[indices]
        betas = self.betas[indices]
        if not do_sample:
            return alphas / (alphas + betas)
        # Only build a generator when a draw is actually requested.
        if rng is None:
            rng = np.random.default_rng()
        return rng.beta(alphas, betas)

    def equivalent_count(self, indices=None):
        """Return ``alpha + beta`` for the selected tasks.

        Args:
            indices: task indices to evaluate; defaults to all ``n`` tasks.

        Returns:
            np.ndarray: one value per selected task.
        """
        if indices is None:
            indices = np.arange(self.n)
        return self.alphas[indices] + self.betas[indices]
class InterpolationBetaPREstimator(BaseBetaPREstimator):
    """Beta pass-rate estimator whose pseudo pass rates interpolate two feature columns.

    For every task the pseudo pass rate is
    ``clip(f0 + (f1 - f0) * cap_coef, 0, 1)``, where ``f0`` and ``f1`` are the
    task's first and second feature columns and ``cap_coef`` is a scalar
    re-estimated from each batch of observed pass rates.
    """

    def __init__(
        self,
        features: np.ndarray,
        m: int,
        lamb,
        rho,
        cap_coef_update_discount=0.9,
        adaptive_rho=False,
    ):
        """Create an estimator with one task per row of ``features``.

        Args:
            features (np.ndarray): per-task features; columns 0 and 1 are used
                as interpolation anchors.
            m (int): repeat times per task.
            lamb: discount factor of historical estimation.
            rho: weight of pseudo counts.
            cap_coef_update_discount: weight kept on the previous ``cap_coef``
                when blending in a new batch estimate.
            adaptive_rho: if true, ``rho`` is halved whenever the batch mean
                absolute prediction error is at least 0.25.
        """
        super().__init__(len(features), m, lamb, rho)
        self.adaptive_rho = adaptive_rho
        self.cap_coef_update_discount = cap_coef_update_discount
        self.cap_coef = None  # set by the first call to update()
        self.features = features  # [D, 2]

    def update(self, ref_indices: List[int], ref_pass_rates: List[float]):
        """Fold a batch of observed pass rates into the Beta parameters.

        Args:
            ref_indices (List[int]): indices of the evaluated tasks.
            ref_pass_rates (List[float]): observed pass rates, paired with
                ``ref_indices``.
        """
        observed = np.array(ref_pass_rates)

        # Batch-level interpolation coefficient: where the mean observed pass
        # rate sits between the batch means of the two anchor columns.
        anchor_means = np.mean(self.features[ref_indices], axis=0)
        numerator = np.mean(ref_pass_rates) - anchor_means[0]
        denominator = anchor_means[1] - anchor_means[0] + 1e-6
        cap_estimate = numerator / denominator

        # First batch initialises the coefficient; later ones are blended in.
        if self.cap_coef is not None:
            discount = self.cap_coef_update_discount
            self.cap_coef = discount * self.cap_coef + (1 - discount) * cap_estimate
        else:
            self.cap_coef = cap_estimate

        # Pseudo pass rate for every task, clipped to [0, 1].
        low = self.features[:, 0]
        high = self.features[:, 1]
        p_tilde = np.clip((high - low) * self.cap_coef + low, 0, 1)

        # The local name below is part of the debug message (``=`` specifier).
        mean_abs_error = np.mean(np.abs(np.array(p_tilde[ref_indices]) - observed))
        if self.adaptive_rho and mean_abs_error >= 0.25:
            self.rho = self.rho * 0.5
            self.logger.debug(f"{mean_abs_error=}, {self.rho=}")

        # Evaluated tasks use their observed rate as the pseudo rate.
        p_tilde[ref_indices] = observed

        # Observed success/failure counts; zero for tasks outside this batch.
        successes = np.zeros(self.n, dtype=float)
        failures = np.zeros(self.n, dtype=float)
        successes[ref_indices] = observed * self.m
        failures[ref_indices] = (1 - observed) * self.m

        self._update(successes, failures, p_tilde)