Source code for trinity.buffer.task_scheduler
# -*- coding: utf-8 -*-
"""The taskset scheduler."""
from collections import Counter
from typing import Dict, List

import numpy as np

from trinity.buffer.buffer import get_buffer_reader
from trinity.buffer.selector import SELECTORS
from trinity.common.config import Config
from trinity.common.constants import SELECTOR_METRIC
from trinity.utils.annotations import Experimental
@Experimental
class TasksetScheduler:
    """
    Coordinates multiple datasets (tasksets) with customizable task selection strategies per taskset.

    The scheduler:

    - Manages multiple data sources (tasksets)
    - Uses a selector per taskset to determine which samples to read
    - Shuffles the order of taskset access across epochs
    - Supports adaptive selectors via feedback (e.g., difficulty-based sampling)
    - Enables curriculum-like or interleaved multi-task training

    It assumes that each call to `read_async()` corresponds to one training step,
    and batches are built by aggregating samples from different tasksets based on
    a shuffled global schedule.
    """
    def __init__(self, explorer_state: Dict, config: Config):
        """
        Initialize the scheduler from configuration and previous state (for resume support).

        Args:
            explorer_state (Dict): Restoration state from a checkpoint (may include progress info)
            config (Config): Full system configuration containing buffer and taskset settings
        """
        self.config = config
        # Backward compatibility: the old format stored 'latest_task_index' directly
        if "latest_task_index" in explorer_state:
            assert len(config.buffer.explorer_input.tasksets) == 1  # old single-taskset format
            explorer_state["taskset_states"] = [
                {
                    "current_index": explorer_state["latest_task_index"],
                }
            ]
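        # For example, a legacy checkpoint state such as
        #   {"latest_task_index": 42, "latest_iteration": 7}
        # is rewritten above into the current multi-taskset format:
        #   {"taskset_states": [{"current_index": 42}], "latest_iteration": 7}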
        self.read_batch_size = config.buffer.batch_size
        taskset_configs = config.buffer.explorer_input.tasksets
        from trinity.buffer.reader.file_reader import TaskFileReader

        taskset_states = explorer_state.get(
            "taskset_states", [{"current_index": 0}] * len(taskset_configs)
        )
        self.tasksets = []
        self.selectors = []
        for taskset_config, taskset_state in zip(taskset_configs, taskset_states):
            assert not taskset_config.is_eval  # training tasksets only; assumes drop-last batching
            taskset = get_buffer_reader(taskset_config)
            if not isinstance(taskset, TaskFileReader):
                raise TypeError(
                    f"Taskset '{taskset_config.name}' has an unsupported type '{type(taskset).__name__}'. "
                    f"Currently, only 'TaskFileReader' is supported by TasksetScheduler."
                )
            # Create a selector of the type specified in config (e.g., 'sequential', 'shuffle')
            selector = SELECTORS.get(taskset_config.task_selector.selector_type)(
                taskset.dataset, taskset_config.task_selector
            )
            selector.load_state_dict(taskset_state)  # restore any prior state
            self.tasksets.append(taskset)
            self.selectors.append(selector)
        # Each explorer step calls read_async once, so the step is tracked globally
        self.step = explorer_state.get("latest_iteration", 0)
        # Build a flat list in which each taskset's ID appears once per task it contains,
        # so larger tasksets occupy proportionally more slots per epoch
        self.base_taskset_ids = []
        for i, taskset in enumerate(self.tasksets):
            self.base_taskset_ids.extend([i] * len(taskset))
        if len(self.base_taskset_ids) == 0:
            raise ValueError("Empty tasksets provided!")
        # Derive the current epoch from the number of tasks consumed so far
        self.epoch = self.step * self.read_batch_size // len(self.base_taskset_ids)
        self.orders = self.build_orders(self.epoch)
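        # Resume arithmetic (illustrative numbers): with batch_size=4, 16 tasks in
        # total, and a checkpoint at step 10, 40 tasks have been consumed, so the
        # scheduler resumes in epoch 40 // 16 == 2.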
    def build_orders(self, epoch: int) -> List[int]:
        """
        Create a shuffled sequence of taskset IDs that controls which taskset each
        sample slot in the epoch draws from.

        At the start of each epoch, taskset IDs are shuffled in proportion to
        taskset size, ensuring balanced exposure while introducing randomness in
        the selection order.

        Args:
            epoch (int): Epoch ID, used as the seed for deterministic shuffling

        Returns:
            List[int]: Sequence of taskset IDs whose length equals the total
                number of tasks across all tasksets
        """
        taskset_ids = self.base_taskset_ids.copy()
        rng = np.random.default_rng(epoch)
        rng.shuffle(taskset_ids)
        return taskset_ids
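        # For example (illustrative sizes): with a 3-task taskset 0 and a 2-task
        # taskset 1, `base_taskset_ids` is [0, 0, 0, 1, 1]. Each epoch shuffles a
        # copy of it, seeded by the epoch number so the schedule is reproducible
        # after a restart, and taskset 0 still fills exactly 3 of the 5 slots:
        #
        #   rng = np.random.default_rng(epoch)  # epoch 0, 1, 2, ...
        #   order = [0, 0, 0, 1, 1]
        #   rng.shuffle(order)                  # some permutation, e.g. [0, 1, 0, 0, 1]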
    async def read_async(self) -> List:
        """
        Asynchronously read a batch of tasks according to the current schedule.

        For each step:

        - Checks whether a new epoch has started; rebuilds the order if so
        - Determines which tasksets contribute to this batch
        - Uses each taskset's selector to pick specific samples
        - Annotates each task with its source taskset_id
        - Returns the combined list of tasks

        Raises:
            StopAsyncIteration: When total_steps (if set) or total_epochs is reached

        Returns:
            List[Task]: A batch of tasks drawn from potentially multiple tasksets
        """
        # Stop on the step budget if one is configured, otherwise on the epoch budget
        if self.config.buffer.total_steps:
            if self.step >= self.config.buffer.total_steps:
                raise StopAsyncIteration
        else:
            if self.epoch >= self.config.buffer.total_epochs:
                raise StopAsyncIteration
        batch_size = self.read_batch_size
        # Locate this step's slice within the current epoch's schedule
        start = self.step * batch_size % len(self.base_taskset_ids)
        end = start + batch_size
        if end <= len(self.base_taskset_ids):
            # The batch fits entirely inside the current epoch
            taskset_ids = self.orders[start:end]
            if end == len(self.base_taskset_ids):
                # The epoch ends exactly at the batch boundary; prepare the next one
                self.epoch += 1
                self.orders = self.build_orders(self.epoch)
        else:
            # The batch straddles an epoch boundary: take the tail of the current
            # schedule, then fill the remainder from the next epoch's schedule
            taskset_ids = self.orders[start:]
            self.epoch += 1
            if self.epoch >= self.config.buffer.total_epochs:
                raise StopAsyncIteration
            self.orders = self.build_orders(self.epoch)
            taskset_ids += self.orders[: (end - len(self.base_taskset_ids))]
        counter = Counter(taskset_ids)
        batch = []
        for taskset_id, count in counter.items():
            indices = self.selectors[taskset_id].get_indices(batch_size=count)
            tasks = await self.tasksets[taskset_id].read_with_indices_async(indices)
            # Annotate each task with its origin
            for task in tasks:
                task.index["taskset_id"] = taskset_id
            batch.extend(tasks)
        self.step += 1
        return batch
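        # Boundary walkthrough (illustrative numbers): with 5 tasks in total and
        # batch_size=3, step 1 has start=3 and end=6 > 5, so the batch takes
        # orders[3:5] from the ending epoch plus orders[:1] from the freshly
        # shuffled next epoch, yielding 2 + 1 = 3 taskset IDs.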
    def state_dict(self) -> List[Dict]:
        """
        Save persistent state for checkpointing.

        Returns:
            List[Dict]: State dicts for all selectors (one per taskset)
        """
        return [selector.state_dict() for selector in self.selectors]
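        # Round-trip sketch (assumed wiring, for illustration): the returned list
        # is what `__init__` expects under the "taskset_states" key when resuming:
        #
        #   explorer_state = {"taskset_states": scheduler.state_dict(),
        #                     "latest_iteration": scheduler.step}
        #   scheduler = TasksetScheduler(explorer_state, config)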
    def update(self, pipeline_metrics: Dict) -> None:
        """
        Update selectors using feedback from the training pipeline.

        Expected format::

            pipeline_metrics = {
                SELECTOR_METRIC: {
                    0: {"indices": [...], "values": [...]},
                    1: {"indices": [...], "values": [...]},
                },
                ...  # other metrics
            }

        This allows adaptive selectors (like `DifficultyBasedSelector`) to refine
        their difficulty estimates.

        Args:
            pipeline_metrics (Dict): Metrics dictionary passed from the explorer.
        """
        if SELECTOR_METRIC not in pipeline_metrics:
            return
        selector_metric = pipeline_metrics[SELECTOR_METRIC]
        for taskset_id, taskset_kwargs in selector_metric.items():
            selector = self.selectors[taskset_id]
            selector.update(**taskset_kwargs)
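        # Example call (hypothetical values): report per-task feedback for two
        # tasks of taskset 0 so its selector can refine their difficulty estimates:
        #
        #   scheduler.update({
        #       SELECTOR_METRIC: {0: {"indices": [3, 7], "values": [0.2, 0.9]}}
        #   })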