trinity.buffer.task_scheduler module

The taskset scheduler.

class trinity.buffer.task_scheduler.TasksetScheduler(explorer_state: Dict, config: Config)

Bases: object

Coordinates multiple datasets (tasksets) with customizable task selection strategies per taskset.

The scheduler:
  • Manages multiple data sources (tasksets)

  • Uses a selector per taskset to determine which samples to read

  • Shuffles the order of taskset access across epochs

  • Supports adaptive selectors via feedback (e.g., difficulty-based sampling)

  • Enables curriculum-like or interleaved multi-task training

It assumes that each call to read_async() corresponds to one training step, and batches are built by aggregating samples from different tasksets based on a shuffled global schedule.
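To make the step-to-batch assumption concrete, here is a minimal sketch of how one step's slice of a flat schedule could be turned into per-taskset sample counts. The helper `tasksets_for_step` and the slicing rule (step `i` consumes entries `i * batch_size` up to `(i + 1) * batch_size`) are illustrative assumptions, not the scheduler's confirmed internals.

```python
from collections import Counter
from typing import Dict, List


def tasksets_for_step(schedule: List[int], step: int, batch_size: int) -> Dict[int, int]:
    """Return {taskset_id: sample_count} for one step of a flat schedule.

    Assumption: step i consumes schedule[i * batch_size : (i + 1) * batch_size].
    """
    window = schedule[step * batch_size : (step + 1) * batch_size]
    return dict(Counter(window))


# Hypothetical schedule covering 2 steps with batch_size = 4 and tasksets 0 and 1.
schedule = [0, 1, 0, 0, 1, 1, 0, 1]
step0 = tasksets_for_step(schedule, step=0, batch_size=4)
step1 = tasksets_for_step(schedule, step=1, batch_size=4)
```

Under this assumption, each taskset's selector would be asked for as many samples as its ID occurs in the step's window.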

__init__(explorer_state: Dict, config: Config)

Initialize the scheduler from configuration and previous state (for resume support).

Parameters:
  • explorer_state (Dict) – Restoration state from checkpoint (may include progress info)

  • config (Config) – Full system configuration containing buffer and taskset settings

build_orders(epoch: int)

Creates a shuffled sequence of taskset IDs to control sampling priority per step.

At the start of each epoch, the order is rebuilt: each taskset is represented in proportion to its size, and the resulting sequence of taskset IDs is shuffled. This keeps exposure balanced across tasksets while randomizing the selection order.

Parameters:

epoch (int) – Epoch ID used as seed for deterministic shuffling

Returns:

Sequence of taskset IDs, length = steps_per_epoch * batch_size

Return type:

List[int]
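The following is a minimal sketch of one way such an order could be built, not the actual implementation. The function `build_orders_sketch`, its parameters, and the largest-remainder rounding are assumptions made for illustration; only the epoch-as-seed behavior and the output length come from the description above.

```python
import random
from typing import List


def build_orders_sketch(
    taskset_sizes: List[int], epoch: int, steps_per_epoch: int, batch_size: int
) -> List[int]:
    """Build a shuffled list of taskset IDs, sized proportionally to each taskset."""
    total_slots = steps_per_epoch * batch_size
    total_size = sum(taskset_sizes)

    # Floor of each taskset's proportional share of the slots.
    quotas = [size * total_slots // total_size for size in taskset_sizes]

    # Hand any slots lost to rounding to the tasksets with the largest remainders.
    remainders = [
        (size * total_slots % total_size, tid) for tid, size in enumerate(taskset_sizes)
    ]
    leftover = total_slots - sum(quotas)
    for _, tid in sorted(remainders, reverse=True)[:leftover]:
        quotas[tid] += 1

    orders = [tid for tid, quota in enumerate(quotas) for _ in range(quota)]

    # Seeding with the epoch ID makes the shuffle reproducible for a given epoch.
    random.Random(epoch).shuffle(orders)
    return orders


# Hypothetical sizes: taskset 0 holds 300 tasks, taskset 1 holds 100.
orders_a = build_orders_sketch([300, 100], epoch=0, steps_per_epoch=2, batch_size=4)
orders_b = build_orders_sketch([300, 100], epoch=0, steps_per_epoch=2, batch_size=4)
```

Because the shuffle is seeded by the epoch, rebuilding the order after a resume yields the same sequence for that epoch.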

async read_async() → List

Asynchronously reads a batch of tasks according to the current schedule.

For each step:
  • Checks if a new epoch has started; rebuilds order if so

  • Determines which tasksets contribute to this batch

  • Uses each taskset’s selector to pick specific samples

  • Annotates each task with its source taskset_id

  • Returns combined list of tasks

Raises:

StopAsyncIteration – When total_epochs is reached

Returns:

A batch of tasks from potentially multiple tasksets

Return type:

List[Task]
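A caller would typically await read_async() once per step and stop when StopAsyncIteration is raised. The sketch below shows that loop against a hypothetical stand-in, `ToyScheduler`, which only mimics the documented contract; plain dicts are used in place of Task objects, and the real scheduler is built from a Config that is not reproduced here.

```python
import asyncio
from typing import List


class ToyScheduler:
    """Hypothetical stand-in that mimics the documented read_async() contract."""

    def __init__(self, total_steps: int, batch_size: int) -> None:
        self.total_steps = total_steps
        self.batch_size = batch_size
        self.step = 0

    async def read_async(self) -> List[dict]:
        # Signal exhaustion the same way the documented method does.
        if self.step >= self.total_steps:
            raise StopAsyncIteration
        # Alternate between two pretend tasksets and tag each task with its source.
        batch = [{"taskset_id": i % 2, "step": self.step} for i in range(self.batch_size)]
        self.step += 1
        return batch


async def consume(scheduler: ToyScheduler) -> List[List[dict]]:
    """Read one batch per step until the scheduler reports it is exhausted."""
    batches = []
    while True:
        try:
            batches.append(await scheduler.read_async())
        except StopAsyncIteration:
            break
    return batches


batches = asyncio.run(consume(ToyScheduler(total_steps=3, batch_size=4)))
```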

state_dict() → List[Dict]

Save persistent state for checkpointing.

Returns:

State dicts for all selectors (one per taskset)

Return type:

List[Dict]
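As a rough illustration of the one-state-per-taskset layout, the sketch below round-trips a list of selector states through JSON. `ToySelector`, its `offset` field, and the `load_state_dict` restore hook are hypothetical; the actual selector state contents and restore path are not described in this section.

```python
import json
from typing import Dict, List


class ToySelector:
    """Hypothetical selector that only tracks how far it has read."""

    def __init__(self) -> None:
        self.offset = 0

    def state_dict(self) -> Dict:
        return {"offset": self.offset}

    def load_state_dict(self, state: Dict) -> None:
        # Hypothetical restore hook, named by analogy with state_dict().
        self.offset = state["offset"]


def save_states(selectors: List[ToySelector]) -> str:
    """Serialize one state dict per selector, preserving taskset order."""
    return json.dumps([selector.state_dict() for selector in selectors])


def restore_states(selectors: List[ToySelector], payload: str) -> None:
    """Load each saved state back into the selector at the same position."""
    for selector, state in zip(selectors, json.loads(payload)):
        selector.load_state_dict(state)


selectors = [ToySelector(), ToySelector()]
selectors[0].offset, selectors[1].offset = 5, 9
payload = save_states(selectors)

restored = [ToySelector(), ToySelector()]
restore_states(restored, payload)
restored_offsets = [selector.offset for selector in restored]
```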

update(pipeline_metrics: Dict) → None

Update selectors using feedback from the training pipeline.

Expected format:

    pipeline_metrics = {
        SELECTOR_METRIC: {
            0: {"indices": [...], "values": [...]},
            1: {"indices": [...], "values": [...]},
        },
        ...  # other metrics
    }

This allows adaptive selectors (like DifficultyBasedSelector) to refine difficulty estimates.

Parameters:

pipeline_metrics (Dict) – Metrics dictionary passed from explorer.
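The feedback flow can be sketched as follows, assuming the keys under SELECTOR_METRIC are taskset IDs and that each selector accepts paired indices and values. The string assigned to `SELECTOR_METRIC` is a placeholder, and `ToyDifficultySelector` and `dispatch_feedback` are hypothetical names used only for illustration.

```python
from typing import Dict, List

# Placeholder value; the real constant's value is not given in this section.
SELECTOR_METRIC = "selector_metric"


class ToyDifficultySelector:
    """Hypothetical selector that stores the latest value seen per sample index."""

    def __init__(self) -> None:
        self.estimates: Dict[int, float] = {}

    def update(self, indices: List[int], values: List[float]) -> None:
        for index, value in zip(indices, values):
            self.estimates[index] = value


def dispatch_feedback(selectors: List[ToyDifficultySelector], pipeline_metrics: Dict) -> None:
    """Route each taskset's feedback to the selector at the same position."""
    feedback = pipeline_metrics.get(SELECTOR_METRIC, {})
    for taskset_id, payload in feedback.items():
        selectors[taskset_id].update(payload["indices"], payload["values"])


selectors = [ToyDifficultySelector(), ToyDifficultySelector()]
pipeline_metrics = {
    SELECTOR_METRIC: {
        0: {"indices": [3, 7], "values": [0.2, 0.9]},
        1: {"indices": [1], "values": [0.5]},
    },
    "other_metric": 1.0,  # unrelated metrics are ignored by the dispatcher
}
dispatch_feedback(selectors, pipeline_metrics)
```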