"""Data selectors."""
from typing import Dict, List
import numpy as np
import torch
from trinity.buffer.reader.file_reader import _HFBatchReader
from trinity.buffer.selector.difficulty_estimator import InterpolationBetaPREstimator
from trinity.common.config import TaskSelectorConfig
from trinity.utils.annotations import Experimental
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry
# Registry of selector classes. Each selector class in this module registers
# itself under a string name via ``@SELECTORS.register_module(...)``.
SELECTORS = Registry("selectors")
@Experimental
class BaseSelector:
    """
    Abstract base class defining the interface for custom data selection strategies.

    A selector determines which samples (by index) are selected from the dataset
    during training. It enables flexible sampling beyond simple sequential or
    random access, supporting active learning, curriculum learning, or
    difficulty-based sampling.

    Subclasses must implement:
        - get_indices: returns indices for the next batch
        - update: updates internal state using feedback (e.g., loss values, mean rewards, etc.)
        - state_dict / load_state_dict: for saving/loading selector state (checkpointing)
    """

    def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig):
        """
        Store the data source and configuration shared by all selectors.

        Args:
            data_source (_HFBatchReader): Reader wrapping the dataset to select from.
            config (TaskSelectorConfig): Selector configuration.
        """
        self.data_source = data_source
        self.config = config

    def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]:
        """
        Select a batch of sample indices from the dataset.

        Args:
            batch_size (int): Number of indices to return
            return_extra_info (bool): If True, implementations may return an
                ``(indices, extra_info)`` tuple where ``extra_info`` is a dict of
                per-batch metadata; implementations that have no metadata may
                ignore this flag.

        Returns:
            List[int]: Selected indices into the dataset (or the tuple described above).

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def update(self, indices: List[int], values: List[float]) -> None:
        """
        Update internal state based on feedback (e.g., model loss, accuracy).

        This allows adaptive selectors (like hard example mining) to learn over time.

        Args:
            indices (List[int]): Previously selected indices
            values (List[float]): Feedback values corresponding to those indices

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def state_dict(self) -> Dict:
        """
        Return serializable state of the selector for checkpointing.

        Returns:
            Dict: State information (e.g., current position, etc.)

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def load_state_dict(self, state_dict: Dict) -> None:
        """
        Restore selector state from a saved dictionary.

        Args:
            state_dict (Dict): Output from state_dict()

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError
 
@SELECTORS.register_module("sequential")
class SequentialSelector(BaseSelector):
    """
    Selects data sequentially in fixed order across epochs.

    Example: [0,1,2,...,B-1], then [B,B+1,...,2B-1], etc. When the end of the
    dataset is reached, indices wrap around to 0.
    """

    def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig):
        super().__init__(data_source, config)
        self.dataset_size = data_source.dataset_size
        # Total number of indices handed out so far; the position within the
        # dataset is this value modulo dataset_size.
        self.current_index = 0

    def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]:
        """
        Return the next ``batch_size`` indices in order, wrapping around the dataset.

        Args:
            batch_size (int): Number of indices to return; may exceed the dataset size.
            return_extra_info (bool): Accepted for interface compatibility; ignored.

        Returns:
            List[int]: Consecutive indices (mod dataset size).
        """
        start = self.current_index % self.dataset_size
        self.current_index += batch_size
        # Apply the modulo per element so that a batch larger than the remaining
        # part of the dataset (even larger than the whole dataset) keeps wrapping
        # instead of producing out-of-range indices.
        return [(start + offset) % self.dataset_size for offset in range(batch_size)]

    def update(self, indices: List[int], values: List[float]) -> None:
        # No-op: sequential selection doesn't adapt based on feedback
        pass

    def state_dict(self) -> Dict:
        """Return the current position for checkpointing."""
        return {
            "current_index": self.current_index,
        }

    def load_state_dict(self, state_dict):
        """Restore the position saved by ``state_dict`` (defaults to 0)."""
        self.current_index = state_dict.get("current_index", 0)
 
@SELECTORS.register_module("shuffle")
class ShuffleSelector(BaseSelector):
    """
    Shuffles the dataset once per epoch and iterates through it sequentially.

    Each epoch uses a different permutation of the full dataset, seeded with
    ``seed + epoch`` so the order is reproducible, including after restoring
    from a checkpoint. When one epoch ends, a new shuffle is triggered.
    Mimics standard PyTorch DataLoader with shuffle=True.
    """

    def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig):
        super().__init__(data_source, config)
        self.dataset_size = data_source.dataset_size  # Total available samples
        self.current_index = 0  # Total number of indices handed out so far
        self.seed = config.seed  # For reproducible shuffling
        self.orders = self._get_orders()  # Shuffled index order for the current epoch

    def _get_orders(self) -> List[int]:
        """
        Generate the shuffled order for the epoch that contains ``current_index``.

        Uses NumPy's default generator seeded with ``seed + epoch``, giving a
        different shuffle per epoch while staying deterministic for a fixed seed.
        """
        epoch = self.current_index // self.dataset_size
        rng = np.random.default_rng(self.seed + epoch)
        return rng.permutation(self.dataset_size).tolist()

    def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]:
        """
        Return the next ``batch_size`` indices from the current epoch's permutation.

        A batch that crosses an epoch boundary continues with the next epoch's
        permutation; batches larger than the dataset span several epochs.

        Args:
            batch_size (int): Number of indices to return.
            return_extra_info (bool): Accepted for interface compatibility; ignored.

        Returns:
            List[int]: Selected indices.
        """
        ret: List[int] = []
        while len(ret) < batch_size:
            start = self.current_index % self.dataset_size
            take = min(batch_size - len(ret), self.dataset_size - start)
            ret.extend(self.orders[start : start + take])
            self.current_index += take
            # At the end of an epoch, reshuffle. This must happen *after*
            # advancing current_index so the new permutation is seeded with the
            # next epoch number (the same seed load_state_dict would use).
            if self.current_index % self.dataset_size == 0:
                self.orders = self._get_orders()
        return ret

    def update(self, indices: List[int], values: List[float]) -> None:
        # No-op: static shuffling does not adapt
        pass

    def state_dict(self) -> Dict:
        """Return the current position for checkpointing."""
        return {
            "current_index": self.current_index,
        }

    def load_state_dict(self, state_dict):
        """Restore the position and regenerate that epoch's permutation."""
        self.current_index = state_dict.get("current_index", 0)
        self.orders = self._get_orders()
 
@SELECTORS.register_module("random")
class RandomSelector(BaseSelector):
    """
    Samples each batch independently and uniformly at random.

    Within a single batch, indices are drawn *without* replacement (no
    duplicates inside a batch). Batches are sampled independently of each other,
    so unlike ShuffleSelector there is no concept of an epoch: the same sample
    can appear in several batches before other samples are seen at all.
    Suitable for online or stochastic training regimes.
    """

    def __init__(self, data_source: _HFBatchReader, config: TaskSelectorConfig):
        super().__init__(data_source, config)
        self.dataset_size = data_source.dataset_size
        # Total number of indices handed out so far; also feeds the per-batch seed.
        self.current_index = 0
        self.seed = config.seed

    def get_indices(self, batch_size: int, return_extra_info: bool = False):
        """
        Draw ``batch_size`` distinct indices uniformly at random.

        Args:
            batch_size (int): Number of indices to draw; must not exceed the
                dataset size (NumPy raises ValueError otherwise).
            return_extra_info (bool): If True, return ``(indices, {})``.

        Returns:
            numpy.ndarray of indices, or ``(indices, {})`` if return_extra_info.
        """
        # Seed varies per batch (seed + current_index) so runs are repeatable.
        rng = np.random.default_rng(self.seed + self.current_index)
        selected_indices = rng.choice(self.dataset_size, batch_size, replace=False)
        self.current_index += batch_size
        if return_extra_info:
            return selected_indices, {}
        else:
            return selected_indices

    def update(self, indices: List[int], values: List[float]) -> None:
        # No-op: basic random selection doesn't adapt
        pass

    def state_dict(self) -> Dict:
        """Return the current position for checkpointing."""
        return {
            "current_index": self.current_index,
        }

    def load_state_dict(self, state_dict):
        """Restore the position saved by ``state_dict`` (defaults to 0)."""
        self.current_index = state_dict.get("current_index", 0)
 
@SELECTORS.register_module("offline_easy2hard")
class OfflineEasy2HardSelector(BaseSelector):
    """
    Selects samples in an 'easy-to-hard' curriculum based on pre-defined difficulty features.

    This selector assumes that higher feature values indicate easier examples.
    It sorts all data once at initialization by descending feature value(s)
    (lexicographically over ``config.feature_keys``), then sequentially serves
    batches from easy → hard over epochs. The sorting is fixed (offline), so no
    online adaptation occurs during training.

    Useful for curriculum learning where sample difficulty is estimated ahead of time
    (e.g., via teacher model confidence, length, BLEU score, etc.).
    """

    def __init__(self, data_source, config: TaskSelectorConfig):
        super().__init__(data_source, config)
        self.logger = get_logger("offline_easy2hard_selector")
        # Dataset columns (e.g. 'loss', 'confidence') used to estimate difficulty.
        feature_keys = config.feature_keys
        if not feature_keys:
            raise ValueError("OfflineEasy2HardSelector requires at least 1 feature key.")
        # Shape: (N, len(feature_keys)) — one row per sample, one column per feature.
        self.features = np.concatenate(
            [np.array(list(data_source.dataset[k]))[:, None] for k in feature_keys], axis=1
        )
        # Append each row's original index so positions survive the sort.
        features_with_index = [list(row) + [i] for i, row in enumerate(self.features)]
        # Sort by feature values in descending order → highest (easiest) first.
        features_with_index = sorted(features_with_index)[::-1]
        self.logger.debug(f"OfflineEasy2HardSelector, sorted {features_with_index[:20]}")
        # Original dataset indices ordered from easiest to hardest.
        self.sorted_index = np.array([i[-1] for i in features_with_index])
        # Curriculum length before wrapping, as reported by the reader.
        self.dataset_size = data_source.dataset_size
        self.current_index = 0

    def update(self, indices: List[int], values: List[float]) -> None:
        # No-op: this selector does not adapt based on runtime feedback
        pass

    def get_indices(self, batch_size, return_extra_info=False):
        """
        Return the next batch of indices in curriculum order (easy → hard).

        Batches are taken sequentially from the pre-sorted list. When the epoch
        ends, it wraps around to the beginning (i.e., restarts the curriculum);
        batches larger than the dataset keep wrapping.

        Args:
            batch_size (int): Number of indices to return.
            return_extra_info (bool): If True, also return a dict with the
                selected ``indices`` and one ``feat<k>`` list per feature column
                (``feat1`` for the first feature key, ``feat2`` for the second, ...).

        Returns:
            numpy.ndarray of indices, or ``(indices, extra_info)`` if return_extra_info.
        """
        start = self.current_index % self.dataset_size
        # Curriculum positions for this batch, wrapping at the end of the epoch.
        positions = (start + np.arange(batch_size)) % self.dataset_size
        selected_indices = self.sorted_index[positions]
        self.current_index += batch_size
        if not return_extra_info:
            return selected_indices
        extra_info = {"indices": selected_indices.tolist()}
        # One entry per feature column instead of assuming exactly two features.
        for col in range(self.features.shape[1]):
            extra_info[f"feat{col + 1}"] = self.features[selected_indices, col].tolist()
        return selected_indices, extra_info

    def state_dict(self) -> Dict:
        """
        Save current position in the curriculum for checkpointing.

        Allows resuming from the same point in the easy→hard progression.
        """
        return {
            "current_index": self.current_index,
        }

    def load_state_dict(self, state_dict):
        """
        Restore progress through the curriculum from saved state.
        """
        self.current_index = state_dict.get("current_index", 0)
 
@SELECTORS.register_module("difficulty_based")
class DifficultyBasedSelector(BaseSelector):
    """
    Adaptive difficulty-based selector using probabilistic modeling of sample difficulty.

    Uses `InterpolationBetaPREstimator` to model each sample's probability of success (PR),
    updated with observed feedback passed to `update`. Then selects samples whose predicted
    PR is close to a target reward (default 1.0), implementing a form of
    *targeted difficulty sampling* — focusing on items near the edge of model capability.

    Supports both greedy selection (`tau=0`) and stochastic sampling (`tau>0`).

    Recognized ``config.kwargs``: ``do_sample``, ``target_reward``, ``tau`` (read here)
    and ``m``, ``lamb``, ``rho``, ``adaptive_rho`` (forwarded to the estimator).
    """

    def __init__(self, data_source, config: TaskSelectorConfig) -> None:
        super().__init__(data_source, config)
        self.logger = get_logger("difficulty_based_selector")
        # Build the difficulty estimator from exactly two feature columns.
        self.diff_estimator = self.build_diff_estimator(
            data_source.dataset, config.feature_keys, config.kwargs
        )
        self.current_index = 0
        self.seed = config.seed
        # Forwarded as predict_pr(do_sample=...); presumably samples PR instead of
        # using a point estimate — confirm against InterpolationBetaPREstimator.
        self.do_sample = config.kwargs.get("do_sample", False)
        self.target_reward = config.kwargs.get("target_reward", 1.0)  # Desired performance level
        self.tau = config.kwargs.get("tau", 1.0)  # Temperature for sampling distribution

    def build_diff_estimator(self, dataset, feature_keys: List[str], config: dict):
        """
        Construct a Beta-distribution-based difficulty estimator from features.

        Expects exactly two feature keys (e.g., ['correct', 'uncertainty']), which are
        concatenated into an (N, 2) feature matrix and passed to
        InterpolationBetaPREstimator for modeling P(success).

        Args:
            dataset: Indexable by column name; each column must be convertible to a list.
            feature_keys (List[str]): Exactly two column names.
            config (dict): Estimator options (``m``, ``lamb``, ``rho``, ``adaptive_rho``).

        Raises:
            ValueError: If ``feature_keys`` does not contain exactly two keys.
        """
        self.logger.debug(f"{config=}")
        if len(feature_keys) != 2:
            raise ValueError(
                f"DifficultyBasedSelector requires exactly 2 feature keys, but got {len(feature_keys)}."
            )
        features = np.concatenate(
            [np.array(list(dataset[k]))[:, None] for k in feature_keys], axis=1
        )
        self.logger.debug(f"{features.shape=}")
        self.logger.debug(f"{features[:5]=}")
        adaptive_rho = config.get("adaptive_rho", False)
        return InterpolationBetaPREstimator(
            features=features,
            m=config.get("m", 16),
            lamb=config.get("lamb", 0.2),
            rho=config.get("rho", 0.2),
            adaptive_rho=adaptive_rho,
        )

    def update(self, indices: List[int], values: List[float]) -> None:
        """
        Update the difficulty estimator with observed performance on selected samples.

        Args:
            indices (List[int]): Previously selected sample indices
            values (List[float]): Observed rewards/scores (e.g., accuracy, BLEU) for those samples
        """
        self.diff_estimator.update(indices, values)

    def get_scores(self) -> np.ndarray:
        """
        Compute selection scores: negative distance between predicted PR and target reward.

        Samples whose predicted performance is closest to `target_reward` receive the
        highest scores, favouring "just right" difficulty (neither too easy nor too hard).

        Returns:
            np.ndarray: One score per sample (consumed via ``torch.from_numpy``).
        """
        # Seeded per call so stochastic PR draws are reproducible across runs.
        rng = np.random.default_rng(self.seed + self.current_index)
        predicted_pr = self.diff_estimator.predict_pr(rng=rng, do_sample=self.do_sample)
        scores = -np.abs(self.target_reward - predicted_pr)
        return scores

    def get_indices(self, batch_size, return_extra_info=False):
        """
        Select a batch of indices based on difficulty proximity to target.

        If tau == 0: take the top-k highest scoring samples (greedy).
        Else: sample without replacement from softmax(scores / tau).

        Args:
            batch_size (int): Number of indices to select.
            return_extra_info (bool): If True, also return a dict with the selected
                ``indices`` and their ``scores``, ``alphas``, ``betas`` and
                ``point`` (= alpha / (alpha + beta)) values from the estimator.

        Returns:
            torch.Tensor of indices, or ``(indices, extra_info)`` if return_extra_info.
        """
        sampling_scores = self.get_scores()
        sampling_scores = torch.from_numpy(sampling_scores)
        if self.tau == 0:
            selected_indices = torch.topk(sampling_scores, batch_size).indices
        else:
            sampling_logits = sampling_scores / self.tau
            # Shift by the max for numerical stability (softmax is shift-invariant).
            sampling_logits -= sampling_logits.max()
            sampling_probabilities = torch.softmax(sampling_logits, dim=0)
            rng = torch.Generator()
            rng.manual_seed(self.seed + self.current_index)
            selected_indices = torch.multinomial(
                sampling_probabilities,
                batch_size,
                replacement=False,
                generator=rng,
            )
        self.logger.debug(f"{selected_indices=}")
        self.logger.debug(f"{sampling_scores=}")
        self.logger.debug(f"{sampling_scores[selected_indices]=}")
        self.current_index += batch_size
        if return_extra_info:
            selected_indices_list = selected_indices.tolist()
            alphas = self.diff_estimator.alphas[selected_indices_list]
            betas = self.diff_estimator.betas[selected_indices_list]
            point_est = alphas / (alphas + betas)
            extra_info = {
                "indices": selected_indices_list,
                "scores": sampling_scores[selected_indices].tolist(),
                "alphas": alphas.tolist(),
                "betas": betas.tolist(),
                "point": point_est.tolist(),
            }
            return selected_indices, extra_info
        else:
            return selected_indices

    def state_dict(self) -> Dict:
        """
        Save current state for checkpointing.

        Only tracks sampling progress. NOTE(review): the estimator's state (e.g. its
        ``alphas``/``betas``) is not included, so after a restore the estimator starts
        from the state built in ``__init__`` — confirm this is intended.
        """
        return {
            "current_index": self.current_index,
        }

    def load_state_dict(self, state_dict):
        """
        Restore selector state (sampling position) from checkpoint.
        """
        self.current_index = state_dict.get("current_index", 0)