trinity.explorer.scheduler module

Scheduler for rollout tasks.

class trinity.explorer.scheduler.TaskWrapper(task: Task, batch_id: int | str, sub_task_num: int = 1, results: List[Tuple[Status, List[Experience]]] = <factory>)

Bases: object

A wrapper for a task. Each task can run multiple times (repeat_times) on the same runner or on different runners.

task: Task
batch_id: int | str
sub_task_num: int = 1
results: List[Tuple[Status, List[Experience]]]
__init__(task: Task, batch_id: int | str, sub_task_num: int = 1, results: List[Tuple[Status, List[Experience]]] = <factory>) → None

trinity.explorer.scheduler.calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) → Dict[str, float]

Calculate task level metrics (mean) from multiple runs of the same task.

Parameters:
  • metrics (List[Dict]) – A list of metric dictionaries from multiple runs of the same task.

  • is_eval (bool) – Whether this is an evaluation task.

Returns:

A dictionary of aggregated metrics, where each metric is averaged over all runs.

Return type:

Dict[str, float]
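The averaging described above can be illustrated with a minimal sketch. The helper name `mean_over_runs` is hypothetical and is not part of this module. The sketch ignores `is_eval`, whose effect is not specified here, and assumes that only numeric values are aggregated.

```python
from collections import defaultdict
from typing import Dict, List


def mean_over_runs(metrics: List[Dict]) -> Dict[str, float]:
    """Average each numeric metric over all runs that report it (hypothetical helper)."""
    collected: Dict[str, List[float]] = defaultdict(list)
    for run_metrics in metrics:
        for key, value in run_metrics.items():
            # Assumption: non-numeric values (and booleans) are skipped.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                collected[key].append(float(value))
    return {key: sum(values) / len(values) for key, values in collected.items()}


# Two runs of the same task, each reporting the same metric keys.
runs = [{"reward": 1.0, "steps": 4}, {"reward": 0.0, "steps": 6}]
averaged = mean_over_runs(runs)
```

Each key is averaged only over the runs that actually report it, so a metric missing from one run does not drag the mean toward zero.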

class trinity.explorer.scheduler.RunnerWrapper(runner_id: int, rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config)

Bases: object

A wrapper for a WorkflowRunner.

__init__(runner_id: int, rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config)
async prepare()
async update_state() → None

Get the runner state.

async run_with_retry(task: TaskWrapper, repeat_times: int, run_id_base: int, timeout: float) → Tuple[Status, List, int, float]
Parameters:
  • task (TaskWrapper) – The task to run.

  • repeat_times (int) – The number of times to repeat the task.

  • run_id_base (int) – The base run ID for the runs of this task.

  • timeout (float) – The timeout for each task run.

Returns:

A tuple of four items: the return status of the task (Status), the experiences generated by the task (List), the runner_id of the current runner (int), and the time taken to run the task (float).

Return type:

Tuple[Status, List, int, float]
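To make the shape of the return value concrete, the following is a minimal sketch of a retry-under-timeout loop. It is not the actual implementation: `run_with_retry_sketch`, `run_once`, and `max_retries` are hypothetical names, a plain `bool` stands in for `Status`, and the retry policy (how many attempts, which errors are retried) is an assumption.

```python
import asyncio
import time
from typing import Awaitable, Callable, List, Tuple


async def run_with_retry_sketch(
    run_once: Callable[[], Awaitable[List]],
    runner_id: int,
    timeout: float,
    max_retries: int = 2,  # assumption: the real retry budget is not documented here
) -> Tuple[bool, List, int, float]:
    """Run `run_once` under a timeout, retrying on failure (hypothetical helper)."""
    start = time.monotonic()
    for _ in range(max_retries + 1):
        try:
            experiences = await asyncio.wait_for(run_once(), timeout=timeout)
            return True, experiences, runner_id, time.monotonic() - start
        except Exception:
            # Timeouts and task errors are treated alike in this sketch: try again.
            continue
    return False, [], runner_id, time.monotonic() - start


calls = {"count": 0}


async def flaky() -> List:
    """Fail on the first call, succeed afterwards."""
    calls["count"] += 1
    if calls["count"] == 1:
        raise RuntimeError("transient failure")
    return ["experience"]


ok, experiences, rid, elapsed = asyncio.run(
    run_with_retry_sketch(flaky, runner_id=3, timeout=1.0)
)
```

The four returned values mirror the documented tuple: status, experiences, runner id, and elapsed time.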

async restart_runner()

trinity.explorer.scheduler.sort_batch_id(batch_id: int | str)

Return the priority of a batch_id, used to order batches.
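A batch_id is either an integer or a string that starts with an integer (see Scheduler.schedule). One plausible way to turn such ids into a priority is sketched below. The function `batch_sort_key` and its tie-breaking rule (plain integers before strings with the same leading number) are assumptions made for illustration, not the documented behavior.

```python
from typing import Tuple, Union


def batch_sort_key(batch_id: Union[int, str]) -> Tuple[int, int]:
    """Hypothetical priority key: leading integer first, plain ints win ties."""
    if isinstance(batch_id, int):
        return (batch_id, 0)
    # Assumption: string ids look like "123/my_task"; use the part before "/".
    return (int(batch_id.split("/", 1)[0]), 1)


ordered = sorted(["10/eval", 2, "2/eval", 1], key=batch_sort_key)
```

Comparing the parsed integer rather than the raw string keeps "10/eval" after "2/eval", which a plain string sort would get wrong.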

class trinity.explorer.scheduler.Scheduler(config: Config, rollout_model: List[InferenceModel], auxiliary_models: List[List[InferenceModel]] | None = None)

Bases: object

Scheduler for rollout tasks.

Supports scheduling tasks to multiple runners, retrying failed tasks, and collecting results at different levels.

__init__(config: Config, rollout_model: List[InferenceModel], auxiliary_models: List[List[InferenceModel]] | None = None)
task_done_callback(async_task: Task)
async start() → None
async stop() → None
schedule(tasks: List[Task], batch_id: int | str) → None

Schedule the provided tasks.

Parameters:
  • tasks (List[Task]) – The tasks to schedule.

  • batch_id (Union[int, str]) – The ID of the provided tasks. It should be an integer or a string starting with an integer (e.g., 123 or “123/my_task”).

dynamic_timeout(timeout: float | None = None) → float

Calculate dynamic timeout based on historical data.
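The exact formula is not given here. As one hedged illustration of a timeout derived from historical data, the sketch below caps a multiple of the average observed run time at an upper bound. The names `dynamic_timeout_sketch`, `history`, `default_timeout`, and `ratio`, as well as the default values, are all assumptions.

```python
from typing import List, Optional


def dynamic_timeout_sketch(
    history: List[float],
    timeout: Optional[float] = None,
    default_timeout: float = 600.0,  # assumption: fallback upper bound in seconds
    ratio: float = 3.0,  # assumption: slack multiplier over the average run time
) -> float:
    """Return min(upper bound, ratio * mean of past run times) (hypothetical helper)."""
    max_timeout = default_timeout if timeout is None else timeout
    if not history:
        # No completed runs yet, so there is nothing to adapt to.
        return max_timeout
    average = sum(history) / len(history)
    return min(max_timeout, average * ratio)


short_runs = dynamic_timeout_sketch([10.0, 20.0], timeout=100.0)
long_runs = dynamic_timeout_sketch([50.0], timeout=100.0)
```

With this rule, quick tasks get a tighter timeout, while slow tasks never exceed the caller-supplied bound.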

async get_results(batch_id: int | str, min_num: int | None = None, timeout: float | None = None, clear_timeout_tasks: bool = True) → Tuple[List[Status], List[Experience]]

Get the results of the tasks in the specified batch_id.

Parameters:
  • batch_id (Union[int, str]) – Only wait for tasks at this batch.

  • min_num (int | None) – The minimum number of tasks to wait for. If None, wait for all tasks at batch_id.

  • timeout (float | None) – The timeout for waiting for tasks to finish. If None, the default timeout is used.

  • clear_timeout_tasks (bool) – Whether to clear timeout tasks.
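The interaction of min_num, timeout, and clear_timeout_tasks can be sketched with plain asyncio tasks. This is an assumption-laden illustration, not the scheduler's code: `wait_for_min` is a hypothetical helper, tasks are assumed not to raise, and "clearing" is modeled as cancelling whatever is still pending once the deadline passes.

```python
import asyncio
from typing import List, Optional, Set


async def wait_for_min(
    pending: Set[asyncio.Task],
    min_num: Optional[int],
    timeout: float,
    clear_timeout_tasks: bool = True,
) -> List:
    """Collect results until `min_num` tasks finish or `timeout` elapses."""
    target = len(pending) if min_num is None else min(min_num, len(pending))
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    results: List = []
    while pending and len(results) < target:
        remaining = deadline - loop.time()
        if remaining <= 0:
            break
        done, pending = await asyncio.wait(
            pending, timeout=remaining, return_when=asyncio.FIRST_COMPLETED
        )
        results.extend(task.result() for task in done)
    if len(results) < target and clear_timeout_tasks:
        # Deadline reached before enough tasks finished: drop the stragglers.
        for task in pending:
            task.cancel()
    return results


async def _demo_min() -> List:
    async def work(delay: float, value: int) -> int:
        await asyncio.sleep(delay)
        return value

    specs = [(0.01, 1), (0.02, 2), (5.0, 3)]
    tasks = {asyncio.create_task(work(delay, value)) for delay, value in specs}
    return await wait_for_min(tasks, min_num=2, timeout=1.0)


collected = asyncio.run(_demo_min())
```

Here the call returns as soon as two of the three tasks finish, without waiting for the slow one.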

has_step(batch_id: int | str) → bool
async wait_all(timeout: float | None = None, clear_timeout_tasks: bool = True) → None

Wait for all tasks to complete without popping results. If the timeout is reached, raise TimeoutError.

Parameters:
  • timeout (float) – Timeout in seconds. TimeoutError is raised when no new task completes within the timeout.

  • clear_timeout_tasks (bool) – Whether to clear timeout tasks.
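Note that the timeout described above is a progress timeout: it applies to the gap between completions, not to the total wait. A minimal sketch of that behavior follows, using plain asyncio tasks. `wait_all_sketch` is a hypothetical helper, and cancelling the stalled tasks before raising is an assumption about what clearing means.

```python
import asyncio
from typing import Set


async def wait_all_sketch(pending: Set[asyncio.Task], timeout: float) -> None:
    """Wait for every task; fail if `timeout` seconds pass with no completion."""
    while pending:
        done, pending = await asyncio.wait(
            pending, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
        )
        if not done:
            # Nothing finished inside the window: cancel stragglers and give up.
            stalled = len(pending)
            for task in pending:
                task.cancel()
            raise TimeoutError(f"{stalled} task(s) made no progress in {timeout}s")


async def _demo_wait_all() -> bool:
    async def work(delay: float) -> None:
        await asyncio.sleep(delay)

    # Both tasks finish well inside the window, so this returns normally.
    fast = {asyncio.create_task(work(0.01)), asyncio.create_task(work(0.02))}
    await wait_all_sketch(fast, timeout=1.0)

    # A single slow task makes no progress within the window.
    try:
        await wait_all_sketch({asyncio.create_task(work(5.0))}, timeout=0.05)
    except TimeoutError:
        return True
    return False


timed_out = asyncio.run(_demo_wait_all())
```

Because the window restarts after each completion, a long batch that keeps making progress is not interrupted.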

get_key_state(key: str) → Dict

Get the state of all runners for the given key.

Parameters:

key (str) – The key of the state to get.

Returns:

A dictionary of runner ids to their state for the given key.

Return type:

Dict

get_runner_state(runner_id: int) → Dict

Get the state of a single runner.

Parameters:

runner_id (int) – The id of the runner.

Returns:

The state of the runner.

Return type:

Dict

get_all_state() → Dict

Get all runners’ state.

Returns:

The state of all runners.

Return type:

Dict

print_all_state() → None

Print all runners’ state in a clear, aligned table format.
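The layout of the printed table is not specified here. As a rough sketch of what an aligned table of runner states could look like, the helper below pads every column to its widest cell. `format_state_table` and the state keys `status` and `task_id` are hypothetical and chosen only for illustration.

```python
from typing import Dict, List


def format_state_table(states: Dict[int, Dict]) -> str:
    """Render {runner_id: state} as a fixed-width text table (hypothetical helper)."""
    # Collect column names in first-seen order; runner_id always comes first.
    columns: List[str] = ["runner_id"]
    for state in states.values():
        for key in state:
            if key not in columns:
                columns.append(key)
    rows = [
        [str(runner_id)] + [str(state.get(col, "")) for col in columns[1:]]
        for runner_id, state in sorted(states.items())
    ]
    # Each column is as wide as its longest cell, header included.
    widths = [
        max([len(col)] + [len(row[i]) for row in rows])
        for i, col in enumerate(columns)
    ]
    lines = [" | ".join(col.ljust(w) for col, w in zip(columns, widths))]
    lines.append("-+-".join("-" * w for w in widths))
    for row in rows:
        lines.append(" | ".join(cell.ljust(w) for cell, w in zip(row, widths)))
    return "\n".join(lines)


table = format_state_table({0: {"status": "busy", "task_id": 12}, 1: {"status": "idle"}})
print(table)
```

Missing keys are rendered as empty cells, so runners that report different fields still line up.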