from typing import List
import numpy as np
from trinity.utils.log import get_logger
[docs]
class BaseBetaPREstimator:
    n: int
    m: int
    lamb: float
    rho: float
    alphas: np.ndarray
    betas: np.ndarray
[docs]
    def __init__(self, n: int, m: int = 16, lamb: float = 0.2, rho: float = 0.2):
        """
        alpha_{t+1} = (1 - lamb) * alpha_t + (1 - rho) * bar{s} + rho * tilde{s}
        beta_{t+1} = (1 - lamb) beta_t + (1 - rho) * bar{f} + rho * tilde{f}
        Args:
            n (int): number of tasks.
            m (int): repeat times per tasks.
            timeout (lamb): discount factor of historical estimation.
            rho (float): weight of pseudo counts.
        """
        self.n = n
        self.m = m
        self.lamb = lamb
        self.rho = rho
        self.alphas = np.ones(n, dtype=float)
        self.betas = np.ones(n, dtype=float)
        self.logger = get_logger("BetaPREstimator")
        self.logger.debug(
            f"{self.n=}, {self.m=}, {self.lamb=}, {self.rho=}, {self.alphas=}, {self.betas=}"
        ) 
[docs]
    def set(self, alphas, betas):
        self.alphas = alphas
        self.betas = betas 
    def _update(self, s_bar, f_bar, p_tilde):
        self.alphas = (
            (1 - self.lamb) * self.alphas
            + self.lamb
            + (1 - self.rho) * s_bar
            + self.rho * p_tilde * self.m
        )
        self.betas = (
            (1 - self.lamb) * self.betas
            + self.lamb
            + (1 - self.rho) * f_bar
            + self.rho * (1 - p_tilde) * self.m
        )
[docs]
    def update(self, ref_indices: List[int], ref_pass_rates: List[float]):
        raise NotImplementedError 
[docs]
    def predict_pr(self, rng=None, indices=None, do_sample=False):
        if rng is None:
            rng = np.random.default_rng()
        if indices is None:
            indices = np.arange(self.n)
        if not do_sample:
            return self.alphas[indices] / (self.alphas[indices] + self.betas[indices])
        else:
            return rng.beta(self.alphas[indices], self.betas[indices]) 
[docs]
    def equivalent_count(self, indices=None):
        if indices is None:
            indices = np.arange(self.n)
        return self.alphas[indices] + self.betas[indices] 
 
[docs]
class InterpolationBetaPREstimator(BaseBetaPREstimator):
[docs]
    def __init__(
        self,
        features: np.ndarray,
        m: int,
        lamb,
        rho,
        cap_coef_update_discount=0.9,
        adaptive_rho=False,
    ):
        super(InterpolationBetaPREstimator, self).__init__(len(features), m, lamb, rho)
        self.features = features  # [D, 2]
        self.cap_coef = None
        self.cap_coef_update_discount = cap_coef_update_discount
        self.adaptive_rho = adaptive_rho 
[docs]
    def update(self, ref_indices: List[int], ref_pass_rates: List[float]):
        ref_pass_rate = np.mean(ref_pass_rates)
        ref_anchor_pass_rates = np.mean(self.features[ref_indices], axis=0)
        cap_estimate = (ref_pass_rate - ref_anchor_pass_rates[0]) / (
            ref_anchor_pass_rates[1] - ref_anchor_pass_rates[0] + 1e-6
        )
        if self.cap_coef is None:
            self.cap_coef = cap_estimate
        else:
            self.cap_coef = (
                self.cap_coef_update_discount * self.cap_coef
                + (1 - self.cap_coef_update_discount) * cap_estimate
            )
        s_bar = np.zeros(self.n, dtype=float)
        s_bar[ref_indices] = np.array(ref_pass_rates) * self.m
        f_bar = np.zeros(self.n, dtype=float)
        f_bar[ref_indices] = (1 - np.array(ref_pass_rates)) * self.m
        p_tilde = np.clip(
            (self.features[:, 1] - self.features[:, 0]) * self.cap_coef + self.features[:, 0], 0, 1
        )
        predicted_pass_rates = p_tilde[ref_indices]
        mean_abs_error = np.mean(np.abs(np.array(predicted_pass_rates) - np.array(ref_pass_rates)))
        if self.adaptive_rho and mean_abs_error >= 0.25:
            self.rho = self.rho * 0.5
        self.logger.debug(f"{mean_abs_error=}, {self.rho=}")
        p_tilde[ref_indices] = np.array(ref_pass_rates)
        self._update(s_bar, f_bar, p_tilde)