trinity.explorer.scheduler module
Scheduler for rollout tasks.
- class trinity.explorer.scheduler.TaskWrapper(task: Task, batch_id: int | str, sub_task_num: int = 1, results: List[Tuple[Status, List[Experience]]] = <factory>)
Bases: object
A wrapper for a task. Each task can run multiple times (repeat_times) on the same or different runners.
- batch_id: int | str
- sub_task_num: int = 1
- results: List[Tuple[Status, List[Experience]]]
- __init__(task: Task, batch_id: int | str, sub_task_num: int = 1, results: List[Tuple[Status, List[Experience]]] = <factory>) → None
- trinity.explorer.scheduler.calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) → Dict[str, float]
Calculate task-level metrics (mean) from multiple runs of the same task.
- Parameters:
metrics (List[Dict]) – A list of metric dictionaries from multiple runs of the same task.
is_eval (bool) – Whether this is an evaluation task.
- Returns:
A dictionary of aggregated metrics, where each metric is averaged over all runs.
- Return type:
Dict[str, float]
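The averaging described above can be sketched as follows. This is a minimal illustration, not the library's implementation: `mean_over_runs` is a hypothetical helper, and any behavior tied to `is_eval` is not documented here, so it is left out.

```python
from collections import defaultdict
from typing import Dict, List


def mean_over_runs(metrics: List[Dict[str, float]]) -> Dict[str, float]:
    """Average each metric over the runs that reported it (hypothetical helper)."""
    collected = defaultdict(list)
    for run in metrics:
        for name, value in run.items():
            collected[name].append(float(value))
    # Assumption: a metric missing from some runs is averaged over the
    # runs that do contain it; the real function may handle this differently.
    return {name: sum(values) / len(values) for name, values in collected.items()}


runs = [{"reward": 1.0, "length": 10}, {"reward": 0.0, "length": 20}]
print(mean_over_runs(runs))  # {'reward': 0.5, 'length': 15.0}
```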
- class trinity.explorer.scheduler.RunnerWrapper(runner_id: int, rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config)
Bases: object
A wrapper for a WorkflowRunner.
- __init__(runner_id: int, rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config)
- async run_with_retry(task: TaskWrapper, repeat_times: int, run_id_base: int, timeout: float) → Tuple[Status, List, int, float]
- Parameters:
task (TaskWrapper) – The task to run.
repeat_times (int) – The number of times to repeat the task.
run_id_base (int) – The base run id for the runs of this task.
timeout (float) – The timeout for each task run.
- Returns:
A tuple of four items. Status: The return status of the task. List: The experiences generated by the task. int: The runner_id of the current runner. float: The time taken to run the task.
- Return type:
Tuple[Status, List, int, float]
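The shape of a retry-with-timeout call that returns a (status, results, runner id, elapsed time) tuple can be sketched as below. This is an illustration under stated assumptions, not the actual `run_with_retry` code: `run_with_retry_sketch`, `make_attempt`, and `max_retries` are hypothetical names, and a plain `bool` stands in for `Status`.

```python
import asyncio
import time
from typing import Any, Awaitable, Callable, List, Tuple


async def run_with_retry_sketch(
    make_attempt: Callable[[], Awaitable[List[Any]]],  # hypothetical task body
    runner_id: int,
    timeout: float,
    max_retries: int = 2,  # hypothetical; not a documented parameter
) -> Tuple[bool, List[Any], int, float]:
    start = time.monotonic()
    ok, results = False, []
    for _ in range(max_retries + 1):
        try:
            # Each attempt gets its own timeout budget.
            results = await asyncio.wait_for(make_attempt(), timeout=timeout)
            ok = True
            break
        except Exception:
            # Covers asyncio.TimeoutError and errors raised by the attempt.
            continue
    return ok, results, runner_id, time.monotonic() - start
```

A caller would unpack the tuple in the same order as the documented return value: status first, then the experiences, the runner id, and the elapsed time.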
- class trinity.explorer.scheduler.Scheduler(config: Config, rollout_model: List[InferenceModel], auxiliary_models: List[List[InferenceModel]] | None = None)
Bases: object
Scheduler for rollout tasks.
Supports scheduling tasks to multiple runners, retrying failed tasks, and collecting results at different levels.
- __init__(config: Config, rollout_model: List[InferenceModel], auxiliary_models: List[List[InferenceModel]] | None = None)
- schedule(tasks: List[Task], batch_id: int | str) → None
Schedule the provided tasks.
- Parameters:
tasks (List[Task]) – The tasks to schedule.
batch_id (Union[int, str]) – The id of the provided tasks. It should be an integer or a string starting with an integer (e.g., 123, “123/my_task”).
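The accepted batch_id forms can be checked with a small helper like the one below. This is only a sketch: `leading_batch_index` is a hypothetical function, and how the scheduler itself interprets the leading integer is not specified here.

```python
import re
from typing import Union


def leading_batch_index(batch_id: Union[int, str]) -> int:
    """Return the integer a batch_id is, or starts with (hypothetical helper)."""
    if isinstance(batch_id, int):
        return batch_id
    match = re.match(r"\d+", batch_id)
    if match is None:
        raise ValueError(f"batch_id must start with an integer: {batch_id!r}")
    return int(match.group())


print(leading_batch_index(123))            # 123
print(leading_batch_index("123/my_task"))  # 123
```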
- dynamic_timeout(timeout: float | None = None) → float
Calculate dynamic timeout based on historical data.
- async get_results(batch_id: int | str, min_num: int | None = None, timeout: float | None = None, clear_timeout_tasks: bool = True) → Tuple[List[Status], List[Experience]]
Get the results of the tasks in the specified batch_id.
- Parameters:
batch_id (Union[int, str]) – Only wait for tasks at this batch.
min_num (int) – The minimum number of tasks to wait for. If None, wait for all tasks at batch_id.
timeout (float) – The timeout for waiting for tasks to finish. If None, the default timeout is used.
clear_timeout_tasks (bool) – Whether to clear timeout tasks.
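The interaction between min_num and timeout can be pictured with plain asyncio. The sketch below is a toy model, not the scheduler's code: `collect_at_least` is a hypothetical function that returns once at least min_num tasks have finished (or all of them when min_num is None), or when the overall timeout expires.

```python
import asyncio
from typing import Any, List, Optional


async def collect_at_least(
    pending: List["asyncio.Task[Any]"],
    min_num: Optional[int],
    timeout: float,
) -> List[Any]:
    # Hypothetical helper: None means "wait for every task".
    target = len(pending) if min_num is None else min(min_num, len(pending))
    results: List[Any] = []
    remaining = set(pending)
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while len(results) < target and remaining:
        time_left = deadline - loop.time()
        if time_left <= 0:
            break  # overall timeout reached; return what has finished so far
        done, remaining = await asyncio.wait(
            remaining, timeout=time_left, return_when=asyncio.FIRST_COMPLETED
        )
        results.extend(task.result() for task in done)
    return results
```

Because several tasks can finish in the same wake-up, the returned list may contain more than min_num results.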
- async wait_all(timeout: float | None = None, clear_timeout_tasks: bool = True) → None
Wait for all tasks to complete without popping results. If the timeout is reached, raise TimeoutError.
- Parameters:
timeout (float) – Timeout in seconds. TimeoutError is raised when no new task completes within the timeout.
clear_timeout_tasks (bool) – Whether to clear timeout tasks.
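The timeout described here is an idle timeout: it restarts whenever a task completes, and only a stretch with no progress raises. A minimal sketch of that pattern with plain asyncio follows; `wait_all_idle` is a hypothetical function and does not model clear_timeout_tasks.

```python
import asyncio
from typing import Any, Iterable


async def wait_all_idle(tasks: Iterable["asyncio.Task[Any]"], idle_timeout: float) -> None:
    """Wait for every task; raise if none finishes within idle_timeout (hypothetical)."""
    remaining = set(tasks)
    while remaining:
        # Each wait gets a fresh idle_timeout, so steady progress never times out.
        done, remaining = await asyncio.wait(
            remaining, timeout=idle_timeout, return_when=asyncio.FIRST_COMPLETED
        )
        if not done:
            raise TimeoutError(f"no task completed within {idle_timeout} seconds")
```

Results stay attached to the task objects, which mirrors waiting without popping results.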
- get_key_state(key: str) → Dict
Get the scheduler state.
- Parameters:
key (str) – The key of the state to get.
- Returns:
A dictionary of runner ids to their state for the given key.
- Return type:
Dict
- get_runner_state(runner_id: int) → Dict
Get the state of the specified runner.
- Parameters:
runner_id (int) – The id of the runner.
- Returns:
The state of the runner.
- Return type:
Dict