trinity.buffer.task_scheduler module

The taskset scheduler.

trinity.buffer.task_scheduler.get_taskset_scheduler(explorer_state: Dict, config: Config) → TasksetSchedulerBase

Get a taskset scheduler according to the config.

Parameters:
  • explorer_state (Dict) -- Restoration state from checkpoint (may include progress info)

  • config (Config) -- Full system configuration containing buffer and taskset settings

Returns:

The taskset scheduler instance

Return type:

TasksetSchedulerBase

class trinity.buffer.task_scheduler.TasksetSchedulerBase(explorer_state: Dict, config: Config)

Bases: object

__init__(explorer_state: Dict, config: Config)

async read_async() → List

Asynchronously reads a batch of tasks according to the current schedule.

state_dict() → List[Dict]

Return persistent state for checkpointing.

Returns:

State dicts for all selectors (one per taskset)

Return type:

List[Dict]

update(pipeline_metrics: Dict) → None

Update selectors using feedback from the training pipeline.

class trinity.buffer.task_scheduler.SimpleTasksetScheduler(explorer_state: Dict, config: Config)

Bases: TasksetSchedulerBase

A simple taskset scheduler that reads from a single taskset and applies no task selection strategy.

__init__(explorer_state: Dict, config: Config)

async read_async() → List

Asynchronously reads a batch of tasks according to the current schedule.

state_dict() → List[Dict]

Return persistent state for checkpointing.

Returns:

State dicts for all selectors (one per taskset)

Return type:

List[Dict]

update(pipeline_metrics: Dict) → None

Update selectors using feedback from the training pipeline.

class trinity.buffer.task_scheduler.TasksetScheduler(explorer_state: Dict, config: Config)

Bases: TasksetSchedulerBase

Coordinates multiple datasets (tasksets) with customizable task selection strategies per taskset.

The scheduler:
  • Manages multiple data sources (tasksets)

  • Uses a selector per taskset to determine which samples to read

  • Shuffles the order of taskset access across epochs

  • Supports adaptive selectors via feedback (e.g., difficulty-based sampling)

  • Enables curriculum-like or interleaved multi-task training

The scheduler assumes that each call to read_async() corresponds to one training step. Each batch is built by aggregating samples from different tasksets according to a shuffled global schedule.

__init__(explorer_state: Dict, config: Config)

Initialize the scheduler from configuration and previous state (for resume support).

Parameters:
  • explorer_state (Dict) -- Restoration state from checkpoint (may include progress info)

  • config (Config) -- Full system configuration containing buffer and taskset settings

build_orders(epoch: int)

Creates a shuffled sequence of taskset IDs to control sampling priority per step.

At the start of each epoch, the taskset IDs are shuffled into a single sequence in which each taskset appears in proportion to its size. This gives every taskset balanced exposure while randomizing the order of selection.

Parameters:

epoch (int) -- Epoch ID used as seed for deterministic shuffling

Returns:

Sequence of taskset IDs, length = steps_per_epoch * batch_size

Return type:

List[int]
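
The proportional, epoch-seeded shuffle described above can be illustrated with a minimal sketch. This is not the library's implementation: the function name build_orders_sketch, the base_seed parameter, and the choice to drop the incomplete final batch are all assumptions made for illustration.

```python
import random
from typing import List


def build_orders_sketch(
    taskset_sizes: List[int], batch_size: int, epoch: int, base_seed: int = 42
) -> List[int]:
    """Return a shuffled list of taskset IDs for one epoch (illustrative only)."""
    # One entry per sample: taskset i appears taskset_sizes[i] times,
    # so larger tasksets occupy proportionally more slots in the schedule.
    orders = [tid for tid, size in enumerate(taskset_sizes) for _ in range(size)]
    # Seeding with the epoch makes the shuffle reproducible, e.g. after a resume.
    # (base_seed is a hypothetical parameter, not taken from the source.)
    rng = random.Random(base_seed + epoch)
    rng.shuffle(orders)
    # Assumption: keep only whole batches, i.e. steps_per_epoch * batch_size entries.
    steps_per_epoch = len(orders) // batch_size
    return orders[: steps_per_epoch * batch_size]


# Three tasksets with 6, 3 and 1 samples, read in batches of 4.
orders = build_orders_sketch([6, 3, 1], batch_size=4, epoch=0)
```

Calling the function twice with the same epoch yields the same sequence, which is the property that lets a resumed run continue from the same schedule.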

async read_async() → List

Asynchronously reads a batch of tasks according to the current schedule.

For each step:
  • Checks if a new epoch has started; rebuilds order if so

  • Determines which tasksets contribute to this batch

  • Uses each taskset's selector to pick specific samples

  • Annotates each task with its source taskset_id

  • Returns combined list of tasks

Raises:

StopAsyncIteration -- When total_epochs is reached

Returns:

A batch of tasks from potentially multiple tasksets

Return type:

List[Task]
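
The per-step logic listed above can be sketched with toy objects. Everything named here (ToySelector, its pick method, read_batch_sketch, and the dict-based task representation) is hypothetical and stands in for the real selector and Task types; epoch rollover and the StopAsyncIteration condition are left out to keep the sketch short.

```python
import asyncio
from collections import Counter
from typing import Dict, List


class ToySelector:
    """Hypothetical stand-in for a per-taskset selector: reads samples in order."""

    def __init__(self, samples: List[str]) -> None:
        self.samples = samples
        self.cursor = 0

    async def pick(self, n: int) -> List[str]:
        # Wrap around when the cursor runs past the end of the taskset.
        picked = [self.samples[(self.cursor + i) % len(self.samples)] for i in range(n)]
        self.cursor += n
        return picked


async def read_batch_sketch(
    orders: List[int], step: int, batch_size: int, selectors: Dict[int, ToySelector]
) -> List[dict]:
    """Build one batch from the slice of the schedule that belongs to `step`."""
    batch_ids = orders[step * batch_size : (step + 1) * batch_size]
    tasks: List[dict] = []
    # Count how many samples each taskset contributes to this batch.
    for taskset_id, count in sorted(Counter(batch_ids).items()):
        for sample in await selectors[taskset_id].pick(count):
            # Annotate every task with the taskset it came from.
            tasks.append({"taskset_id": taskset_id, "sample": sample})
    return tasks


selectors = {0: ToySelector(["a0", "a1", "a2"]), 1: ToySelector(["b0", "b1"])}
batch = asyncio.run(
    read_batch_sketch([0, 1, 0, 0, 1, 0], step=0, batch_size=3, selectors=selectors)
)
```

In this sketch the first batch covers schedule entries [0, 1, 0], so taskset 0 contributes two samples and taskset 1 contributes one.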

state_dict() → List[Dict]

Save persistent state for checkpointing.

Returns:

State dicts for all selectors (one per taskset)

Return type:

List[Dict]
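
As a rough illustration of the "one state dict per taskset" shape, the sketch below saves and restores a list of selector states. CursorSelector and scheduler_state_dict are hypothetical names, and the single "cursor" field is an assumed example of selector state, not the library's actual checkpoint layout.

```python
from typing import Dict, List, Optional


class CursorSelector:
    """Hypothetical selector whose only persistent state is a read cursor."""

    def __init__(self, state: Optional[Dict] = None) -> None:
        # Resume from a saved cursor if a state dict is given, else start at 0.
        self.cursor = (state or {}).get("cursor", 0)

    def state_dict(self) -> Dict:
        return {"cursor": self.cursor}


def scheduler_state_dict(selectors: List[CursorSelector]) -> List[Dict]:
    """Collect one state dict per selector, in taskset order."""
    return [selector.state_dict() for selector in selectors]


selectors = [CursorSelector(), CursorSelector()]
selectors[0].cursor = 128  # pretend taskset 0 has already served 128 samples
saved = scheduler_state_dict(selectors)
restored = [CursorSelector(state) for state in saved]
```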

update(pipeline_metrics: Dict) → None

Update selectors using feedback from the training pipeline.

Expected format:

    pipeline_metrics = {
        SELECTOR_METRIC: {
            0: {"indices": [...], "values": [...]},
            1: {"indices": [...], "values": [...]},
        },
        ...  # other metrics
    }

This allows adaptive selectors (like DifficultyBasedSelector) to refine difficulty estimates.

Parameters:

pipeline_metrics (Dict) -- Metrics dictionary passed from the explorer.
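
A minimal sketch of how feedback in this format might be routed to per-taskset selectors follows. The string value given to SELECTOR_METRIC is a placeholder, and RecordingSelector and dispatch_feedback are hypothetical names; the real selectors' update signature is not shown in this page and is assumed here.

```python
from typing import Dict, List

# Placeholder value: the real constant's value is not shown in this page.
SELECTOR_METRIC = "selector_metric"


class RecordingSelector:
    """Hypothetical selector that stores the latest value reported per sample index."""

    def __init__(self) -> None:
        self.latest: Dict[int, float] = {}

    def update(self, indices: List[int], values: List[float]) -> None:
        # Pair each sample index with its reported value.
        for index, value in zip(indices, values):
            self.latest[index] = value


def dispatch_feedback(pipeline_metrics: Dict, selectors: List[RecordingSelector]) -> None:
    """Route each taskset's feedback to the selector with the same taskset ID."""
    # Metrics under other keys are ignored; a missing selector entry is a no-op.
    for taskset_id, feedback in pipeline_metrics.get(SELECTOR_METRIC, {}).items():
        selectors[taskset_id].update(feedback["indices"], feedback["values"])


selectors = [RecordingSelector(), RecordingSelector()]
pipeline_metrics = {
    SELECTOR_METRIC: {
        0: {"indices": [3, 7], "values": [0.25, 1.0]},
        1: {"indices": [0], "values": [0.5]},
    },
    "other_metric": 0.9,
}
dispatch_feedback(pipeline_metrics, selectors)
```

An adaptive selector would replace the simple dictionary write with its own estimate update, for example a running average of difficulty per sample.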