trinity.buffer.selector.selector module

Data selectors.

class trinity.buffer.selector.selector.BaseSelector(data_source: _HFBatchReader, config: TaskSelectorConfig)

Bases: object

Abstract base class defining the interface for custom data selection strategies.

A selector determines which samples (by index) are selected from the dataset during training. It enables flexible sampling beyond simple sequential or random access, supporting active learning, curriculum learning, or difficulty-based sampling in the future.

Subclasses must implement:
  • get_indices: returns list of indices for next batch

  • update: updates internal state using feedback (e.g., loss values, mean rewards, etc.)

  • state_dict / load_state_dict: for saving/loading selector state (checkpointing)
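The four-method contract listed above can be illustrated with a small standalone sketch. It does not import trinity: `SelectorInterface` is a hypothetical stand-in for the interface, and `HighestLossSelector` is an invented toy strategy, so the constructor arguments differ from the real `BaseSelector(data_source, config)`.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class SelectorInterface(ABC):
    """Hypothetical stand-in that mirrors the four methods listed above."""

    @abstractmethod
    def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]: ...

    @abstractmethod
    def update(self, indices: List[int], values: List[float]) -> None: ...

    @abstractmethod
    def state_dict(self) -> Dict: ...

    @abstractmethod
    def load_state_dict(self, state_dict: Dict) -> None: ...


class HighestLossSelector(SelectorInterface):
    """Toy strategy: prefer the samples with the highest recorded feedback value."""

    def __init__(self, dataset_size: int) -> None:
        # Unseen samples start at +inf so that each one is visited at least once.
        self.last_value = [float("inf")] * dataset_size

    def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[int]:
        # Stable sort: ties keep their original index order.
        order = sorted(range(len(self.last_value)), key=lambda i: -self.last_value[i])
        return order[:batch_size]

    def update(self, indices: List[int], values: List[float]) -> None:
        for i, v in zip(indices, values):
            self.last_value[i] = v

    def state_dict(self) -> Dict:
        return {"last_value": list(self.last_value)}

    def load_state_dict(self, state_dict: Dict) -> None:
        self.last_value = list(state_dict["last_value"])
```

A training loop would alternate `get_indices` and `update`, and call `state_dict` whenever a checkpoint is written.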

__init__(data_source: _HFBatchReader, config: TaskSelectorConfig)
get_indices(batch_size: int, return_extra_info: bool = False) → List[int]

Select a batch of sample indices from the dataset.

Parameters:
  • batch_size (int) -- Number of indices to return

  • return_extra_info (bool) -- If True, may return additional metadata (future use)

Returns:

Selected indices into the dataset

Return type:

List[int]

update(indices: List[int], values: List[float]) → None

Update internal state based on feedback (e.g., model loss, accuracy).

This allows adaptive selectors (like hard example mining) to learn over time.

Parameters:
  • indices (List[int]) -- Previously selected indices

  • values (List[float]) -- Feedback values corresponding to those indices

state_dict() → Dict

Return serializable state of the selector for checkpointing.

Returns:

State information (e.g., current position, etc.)

Return type:

Dict

load_state_dict(state_dict: Dict) → None

Restore selector state from a saved dictionary.

Parameters:

state_dict (Dict) -- Output from state_dict()

class trinity.buffer.selector.selector.SequentialSelector(data_source: _HFBatchReader, config: TaskSelectorConfig)

Bases: BaseSelector

Selects data sequentially in fixed order across epochs.

Example: [0,1,2,...,B-1], then [B,B+1,...,2B-1], etc.
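The index pattern in the example can be reproduced with a few lines of standalone Python. This is a sketch, not the class's implementation: the helper name `sequential_batch` is hypothetical, and wrapping around at the end of the dataset with a modulo is an assumption about how the epoch boundary is handled.

```python
from typing import List


def sequential_batch(dataset_size: int, start: int, batch_size: int) -> List[int]:
    """Return batch_size consecutive indices beginning at start.

    Assumption: indices wrap to 0 once the end of the dataset is reached.
    """
    return [(start + offset) % dataset_size for offset in range(batch_size)]


# Walk through a 5-sample dataset with batch size B = 2.
position = 0
batches = []
for _ in range(3):
    batches.append(sequential_batch(dataset_size=5, start=position, batch_size=2))
    position += 2  # the only state that would need checkpointing
```

Here `batches` ends up as `[[0, 1], [2, 3], [4, 0]]`; the running `position` is the kind of value the documentation's "current position" in `state_dict()` refers to.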

__init__(data_source: _HFBatchReader, config: TaskSelectorConfig)
get_indices(batch_size: int, return_extra_info: bool = False) → List[int]

Select a batch of sample indices from the dataset.

Parameters:
  • batch_size (int) -- Number of indices to return

  • return_extra_info (bool) -- If True, may return additional metadata (future use)

Returns:

Selected indices into the dataset

Return type:

List[int]

update(indices: List[int], values: List[float]) → None

Update internal state based on feedback (e.g., model loss, accuracy).

This allows adaptive selectors (like hard example mining) to learn over time.

Parameters:
  • indices (List[int]) -- Previously selected indices

  • values (List[float]) -- Feedback values corresponding to those indices

state_dict() → Dict

Return serializable state of the selector for checkpointing.

Returns:

State information (e.g., current position, etc.)

Return type:

Dict

load_state_dict(state_dict)

Restore selector state from a saved dictionary.

Parameters:

state_dict (Dict) -- Output from state_dict()

class trinity.buffer.selector.selector.ShuffleSelector(data_source: _HFBatchReader, config: TaskSelectorConfig)

Bases: BaseSelector

Shuffles dataset once per epoch and iterates through it sequentially.

Each epoch uses a different permutation of a subset of the full dataset. When one epoch ends, a new shuffle is triggered. Mimics standard PyTorch DataLoader with shuffle=True.
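The once-per-epoch shuffle can be sketched as a standalone generator. The function name `shuffled_epochs`, the explicit `seed` argument, and the choice to drop a trailing partial batch are all assumptions made for illustration; the real selector may seed and handle epoch boundaries differently.

```python
import random
from typing import Iterator, List


def shuffled_epochs(
    dataset_size: int, batch_size: int, num_epochs: int, seed: int = 0
) -> Iterator[List[int]]:
    """Yield batches of indices, reshuffling the index order at the start of each epoch."""
    rng = random.Random(seed)
    for _ in range(num_epochs):
        order = list(range(dataset_size))
        rng.shuffle(order)  # a fresh permutation for this epoch
        # Assumption: a trailing batch smaller than batch_size is dropped.
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            yield order[start:start + batch_size]


all_batches = list(shuffled_epochs(dataset_size=6, batch_size=2, num_epochs=2))
```

Within each epoch every index appears exactly once, which is the property that distinguishes this strategy from independent random sampling.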

__init__(data_source: _HFBatchReader, config: TaskSelectorConfig)
get_indices(batch_size: int, return_extra_info: bool = False) → List[int]

Select a batch of sample indices from the dataset.

Parameters:
  • batch_size (int) -- Number of indices to return

  • return_extra_info (bool) -- If True, may return additional metadata (future use)

Returns:

Selected indices into the dataset

Return type:

List[int]

update(indices: List[int], values: List[float]) → None

Update internal state based on feedback (e.g., model loss, accuracy).

This allows adaptive selectors (like hard example mining) to learn over time.

Parameters:
  • indices (List[int]) -- Previously selected indices

  • values (List[float]) -- Feedback values corresponding to those indices

state_dict() → Dict

Return serializable state of the selector for checkpointing.

Returns:

State information (e.g., current position, etc.)

Return type:

Dict

load_state_dict(state_dict)

Restore selector state from a saved dictionary.

Parameters:

state_dict (Dict) -- Output from state_dict()

class trinity.buffer.selector.selector.RandomSelector(data_source: _HFBatchReader, config: TaskSelectorConfig)

Bases: BaseSelector

Uniformly samples batches randomly with replacement per batch.

Unlike ShuffleSelector, there is no concept of an epoch: every batch is sampled independently, so the same sample can recur across batches before every sample has been seen once. Suitable for online or stochastic training regimes.
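A minimal sketch of epoch-free sampling follows. The helper `random_batch` is hypothetical, and it assumes indices are distinct within a single batch while separate batches are independent; if the real selector also allows duplicates inside one batch, `random.choices` would be the closer analogue.

```python
import random
from typing import List


def random_batch(dataset_size: int, batch_size: int, rng: random.Random) -> List[int]:
    """Draw one batch uniformly at random, independent of any earlier batch.

    Assumption: no duplicates inside a batch (random.sample); nothing prevents
    the same index from appearing again in a later batch.
    """
    return rng.sample(range(dataset_size), batch_size)


rng = random.Random(0)
batch_a = random_batch(dataset_size=10, batch_size=3, rng=rng)
batch_b = random_batch(dataset_size=10, batch_size=3, rng=rng)
```

Because no position in the dataset is tracked, the only state worth checkpointing in a scheme like this is whatever drives the random number generator.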

__init__(data_source: _HFBatchReader, config: TaskSelectorConfig)
get_indices(batch_size, return_extra_info=False)

Select a batch of sample indices from the dataset.

Parameters:
  • batch_size (int) -- Number of indices to return

  • return_extra_info (bool) -- If True, may return additional metadata (future use)

Returns:

Selected indices into the dataset

Return type:

List[int]

update(indices: List[int], values: List[float]) → None

Update internal state based on feedback (e.g., model loss, accuracy).

This allows adaptive selectors (like hard example mining) to learn over time.

Parameters:
  • indices (List[int]) -- Previously selected indices

  • values (List[float]) -- Feedback values corresponding to those indices

state_dict() → Dict

Return serializable state of the selector for checkpointing.

Returns:

State information (e.g., current position, etc.)

Return type:

Dict

load_state_dict(state_dict)

Restore selector state from a saved dictionary.

Parameters:

state_dict (Dict) -- Output from state_dict()

class trinity.buffer.selector.selector.OfflineEasy2HardSelector(data_source, config: TaskSelectorConfig)

Bases: BaseSelector

Selects samples in an 'easy-to-hard' curriculum based on pre-defined difficulty features.

This selector assumes that higher feature values indicate easier examples. It sorts all data once at initialization by descending feature value(s), then sequentially serves batches from easy → hard over epochs. The sorting is fixed (offline), so no online adaptation occurs during training.

Useful for curriculum learning where sample difficulty is estimated ahead of time (e.g., via teacher model confidence, length, BLEU score, etc.).
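The one-time sort described above can be sketched for a single difficulty feature. The helper `easy_to_hard_order` and the `confidence` values are hypothetical; the real selector reads its feature(s) from the dataset and may combine several keys.

```python
from typing import List, Sequence


def easy_to_hard_order(features: Sequence[float]) -> List[int]:
    """Return dataset indices sorted by descending feature value.

    Higher values are treated as easier, so the result runs easy -> hard.
    """
    return sorted(range(len(features)), key=lambda i: features[i], reverse=True)


confidence = [0.2, 0.9, 0.5, 0.7]  # hypothetical teacher-confidence scores
order = easy_to_hard_order(confidence)

# Batches are then consecutive slices of the fixed order.
first_batch, second_batch = order[:2], order[2:4]
```

Since the order is computed once and never revised, feedback passed to `update` has nothing to change in a scheme like this.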

__init__(data_source, config: TaskSelectorConfig)
update(indices: List[int], values: List[float]) → None

Update internal state based on feedback (e.g., model loss, accuracy).

This allows adaptive selectors (like hard example mining) to learn over time.

Parameters:
  • indices (List[int]) -- Previously selected indices

  • values (List[float]) -- Feedback values corresponding to those indices

get_indices(batch_size, return_extra_info=False)

Returns next batch of indices in curriculum order (easy → hard).

Batches are taken sequentially from the pre-sorted list. When epoch ends, it wraps around to the beginning (i.e., restarts curriculum).

state_dict() → Dict

Save current position in the curriculum for checkpointing. Allows resuming from same point in the easy→hard progression.

load_state_dict(state_dict)

Restore progress through the curriculum from saved state.

class trinity.buffer.selector.selector.DifficultyBasedSelector(data_source, config: TaskSelectorConfig)

Bases: BaseSelector

Adaptive difficulty-based selector using probabilistic modeling of sample difficulty.

Uses InterpolationBetaPREstimator to model each sample's probability of success (PR), updated with observed feedback (e.g., loss, accuracy). It then selects samples whose predicted PR is close to a target reward (e.g., 1.0 for perfect performance). This is a form of targeted difficulty sampling that focuses on items near the edge of model capability.

Supports both greedy selection (tau=0) and stochastic sampling (tau>0).
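To give a feel for what a Beta-distribution model of success probability looks like, here is a plain Beta-Bernoulli update per sample. This is explicitly not InterpolationBetaPREstimator, whose internals are not described on this page; the class name `BetaSuccessEstimate`, the uniform Beta(1, 1) prior, and the fractional-count update are all assumptions for illustration.

```python
from typing import List


class BetaSuccessEstimate:
    """Hypothetical per-sample Beta(alpha, beta) belief over the probability of success."""

    def __init__(self, num_samples: int, prior_alpha: float = 1.0, prior_beta: float = 1.0) -> None:
        self.alpha = [prior_alpha] * num_samples
        self.beta = [prior_beta] * num_samples

    def update(self, indices: List[int], rewards: List[float]) -> None:
        # Assumption: a reward in [0, 1] counts as a fractional success.
        for i, r in zip(indices, rewards):
            self.alpha[i] += r
            self.beta[i] += 1.0 - r

    def mean(self) -> List[float]:
        # Posterior mean of Beta(alpha, beta) is alpha / (alpha + beta).
        return [a / (a + b) for a, b in zip(self.alpha, self.beta)]


estimate = BetaSuccessEstimate(num_samples=2)
estimate.update(indices=[0], rewards=[1.0])
predicted = estimate.mean()
```

After one success, sample 0 moves from a prior mean of 1/2 to 2/3 while the unobserved sample 1 stays at 1/2.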

__init__(data_source, config: TaskSelectorConfig) → None
build_diff_estimator(dataset, feature_keys: List[str], config: dict)

Constructs a Beta-distribution-based difficulty estimator from features.

Expects exactly two feature keys (e.g., ['correct', 'uncertainty']), which are concatenated into a feature matrix and passed to InterpolationBetaPREstimator for modeling P(success).

update(indices: List[int], values: List[float]) → None

Updates the difficulty estimator with observed performance on selected samples.

Parameters:
  • indices (List[int]) -- Previously selected sample indices

  • values (List[float]) -- Observed rewards/scores (e.g., accuracy, BLEU) for those samples

get_scores() → List[float]

Computes selection scores: negative distance between predicted PR and target reward.

Samples whose predicted performance is closest to target_reward receive highest scores. Encourages selection of "just right" difficulty samples (neither too easy nor too hard).
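The scoring rule can be written out in a few lines. This sketch assumes "distance" means the absolute difference; the function name `selection_scores` and the sample values are hypothetical.

```python
from typing import List, Sequence


def selection_scores(predicted_pr: Sequence[float], target_reward: float) -> List[float]:
    """Score each sample as the negative absolute gap between its predicted PR and the target."""
    return [-abs(p - target_reward) for p in predicted_pr]


predicted_pr = [0.1, 0.5, 0.9]  # hypothetical predicted success probabilities
scores = selection_scores(predicted_pr, target_reward=0.5)

# The highest score belongs to the sample closest to the target.
best_index = max(range(len(scores)), key=lambda i: scores[i])
```

All scores are at most zero, and a sample that hits the target exactly scores zero, so ranking by score is the same as ranking by closeness to `target_reward`.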

get_indices(batch_size, return_extra_info=False)

Selects batch of indices based on difficulty proximity to target.

If tau == 0: take top-k highest scoring samples (greedy). Else: sample stochastically using softmax(logits / tau).
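Both branches can be sketched with the standard library. The function `select_by_score` is hypothetical, and drawing without replacement in the stochastic branch is an assumption; the page does not say whether the real selector allows repeats within a batch.

```python
import math
import random
from typing import List, Sequence


def select_by_score(
    scores: Sequence[float], batch_size: int, tau: float, rng: random.Random
) -> List[int]:
    """Pick batch_size indices: greedy top-k when tau == 0, softmax sampling otherwise."""
    if tau == 0:
        # Greedy: the batch_size highest-scoring indices.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        return ranked[:batch_size]
    # Softmax weights exp(score / tau), shifted by the max score for numerical stability.
    top = max(scores)
    weights = [math.exp((s - top) / tau) for s in scores]
    # Assumption: sample without replacement, one index at a time.
    # A very small tau can underflow the remaining weights to zero, which
    # random.choices rejects; this sketch does not guard against that.
    candidates = list(range(len(scores)))
    chosen: List[int] = []
    for _ in range(batch_size):
        slot = rng.choices(range(len(candidates)), weights=[weights[i] for i in candidates])[0]
        chosen.append(candidates.pop(slot))
    return chosen


scores = [-0.4, 0.0, -0.1, -0.3]  # hypothetical selection scores
greedy = select_by_score(scores, batch_size=2, tau=0, rng=random.Random(0))
sampled = select_by_score(scores, batch_size=2, tau=0.5, rng=random.Random(0))
```

A larger `tau` flattens the weights toward uniform sampling, while a small positive `tau` concentrates probability on the top-scoring samples and approaches the greedy result.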

state_dict() → Dict

Save current state for checkpointing. Only tracks sampling progress; actual difficulty estimates are in diff_estimator.

load_state_dict(state_dict)

Restore selector state from checkpoint.