trinity.explorer.scheduler module

Scheduler for rollout tasks.

class trinity.explorer.scheduler.TaskWrapper(task: ~trinity.common.workflows.workflow.Task, batch_id: int | str, sub_task_num: int = 1, results: ~typing.List[~typing.Tuple[~trinity.explorer.workflow_runner.Status, ~typing.List[~trinity.common.experience.Experience]]] = <factory>) [source]

Bases: object

A wrapper for a task. Each task can run multiple times (repeat_times) on the same or different runners.

task: Task
batch_id: int | str
sub_task_num: int = 1
results: List[Tuple[Status, List[Experience]]]
__init__(task: ~trinity.common.workflows.workflow.Task, batch_id: int | str, sub_task_num: int = 1, results: ~typing.List[~typing.Tuple[~trinity.explorer.workflow_runner.Status, ~typing.List[~trinity.common.experience.Experience]]] = <factory>) → None
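The <factory> default on results is how a dataclass field built with default_factory appears in a rendered signature, so every TaskWrapper gets its own results list. The minimal sketch below mirrors the documented fields with a hypothetical TaskWrapperSketch class; Task, Status and Experience are replaced by placeholder aliases, so this is an illustration of the field layout, not the library's source.

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple, Union

# Hypothetical placeholders for trinity's Task, Status and Experience types.
Task = Any
Status = Any
Experience = Any


@dataclass
class TaskWrapperSketch:
    """Hypothetical mirror of TaskWrapper's documented fields."""

    task: Task
    batch_id: Union[int, str]
    sub_task_num: int = 1
    # default_factory=list is what renders as `<factory>`: each instance
    # gets its own list instead of sharing one mutable default.
    results: List[Tuple[Status, List[Experience]]] = field(default_factory=list)


a = TaskWrapperSketch(task="task-a", batch_id=0)
b = TaskWrapperSketch(task="task-b", batch_id="0/eval")
a.results.append(("ok", []))
print(len(a.results), len(b.results))  # 1 0
```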
trinity.explorer.scheduler.calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) → Dict[str, float] [source]

Calculate task level metrics (mean) from multiple runs of the same task.

Parameters:
  • metrics (List[Dict]) -- A list of metric dictionaries from multiple runs of the same task.

  • is_eval (bool) -- Whether this is an evaluation task.

Returns:

A dictionary of aggregated metrics, where each metric is averaged over all runs.

Return type:

Dict[str, float]
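A minimal sketch of per-key averaging, assuming each run reports a flat dictionary of numeric metrics. The function name is hypothetical, averaging each key only over the runs that report it is an assumption, and is_eval is ignored because the documentation does not say how that flag changes the result.

```python
from typing import Dict, List


def calculate_task_level_metrics_sketch(metrics: List[Dict], is_eval: bool) -> Dict[str, float]:
    """Hypothetical sketch: average every metric key over the runs that report it."""
    # is_eval is accepted for signature parity only; how the real function
    # uses it is not documented, so this sketch ignores it.
    sums: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    for run in metrics:
        for key, value in run.items():
            sums[key] = sums.get(key, 0.0) + float(value)
            counts[key] = counts.get(key, 0) + 1
    return {key: sums[key] / counts[key] for key in sums}


runs = [{"reward": 1.0, "length": 10}, {"reward": 0.0, "length": 20}]
print(calculate_task_level_metrics_sketch(runs, is_eval=False))
# {'reward': 0.5, 'length': 15.0}
```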

class trinity.explorer.scheduler.RunnerWrapper(runner_id: int, rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config) [source]

Bases: object

A wrapper for a WorkflowRunner.

__init__(runner_id: int, rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config) [source]
async prepare() [source]
async update_state() → None [source]

Get the runner state.

async run_with_retry(task: TaskWrapper, repeat_times: int, run_id_base: int, timeout: float) → Tuple[Status, List, int, float] [source]
Parameters:
  • task (TaskWrapper) -- The task to run.

  • repeat_times (int) -- The number of times to repeat the task.

  • run_id_base (int) -- The base run id for the runs of this task.

  • timeout (float) -- The timeout for each task run.

Returns:

  • Status -- The return status of the task.

  • List -- The experiences generated by the task.

  • int -- The runner_id of the current runner.

  • float -- The time taken to run the task.

Return type:

Tuple[Status, List, int, float]
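The retry-with-timeout pattern can be sketched with asyncio.wait_for. Everything here is hypothetical: run_fn stands in for the real workflow runner, the strings "ok" and "failed" stand in for Status, max_retries is an invented parameter, and assigning consecutive run ids starting at run_id_base is an assumption. Only the order of the returned tuple (status, experiences, runner_id, elapsed time) follows the documentation.

```python
import asyncio
import time
from typing import Awaitable, Callable, List, Tuple

# Hypothetical: a run function takes a run id and returns a list of experiences.
RunFn = Callable[[int], Awaitable[List[str]]]


async def run_with_retry_sketch(
    run_fn: RunFn,
    runner_id: int,
    repeat_times: int,
    run_id_base: int,
    timeout: float,
    max_retries: int = 2,  # hypothetical parameter
) -> Tuple[str, List[str], int, float]:
    start = time.monotonic()
    experiences: List[str] = []
    for i in range(repeat_times):
        run_id = run_id_base + i  # assumption: consecutive run ids per repeat
        for attempt in range(max_retries + 1):
            try:
                # Each individual run gets its own timeout.
                experiences.extend(await asyncio.wait_for(run_fn(run_id), timeout))
                break
            except Exception:  # includes asyncio.TimeoutError
                if attempt == max_retries:
                    return "failed", experiences, runner_id, time.monotonic() - start
    return "ok", experiences, runner_id, time.monotonic() - start


calls = {"n": 0}


async def flaky(run_id: int) -> List[str]:
    calls["n"] += 1
    if calls["n"] == 1:
        raise RuntimeError("transient failure")  # first call fails, retry succeeds
    return [f"exp-{run_id}"]


status, exps, rid, elapsed = asyncio.run(
    run_with_retry_sketch(flaky, runner_id=3, repeat_times=2, run_id_base=100, timeout=1.0)
)
print(status, exps, rid)  # ok ['exp-100', 'exp-101'] 3
```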

async restart_runner() [source]
trinity.explorer.scheduler.sort_batch_id(batch_id: int | str) [source]

Return the sort priority of batch_id.
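Because a batch_id may be an integer or a string that starts with an integer (for example 123 or "123/my_task"), a sort key has to compare the numeric prefix rather than the raw string. A hypothetical sketch, assuming the numeric prefix dominates and that a plain integer id sorts before string ids sharing its prefix:

```python
import re
from typing import Tuple, Union


def sort_batch_id_sketch(batch_id: Union[int, str]) -> Tuple[int, int, str]:
    """Hypothetical sort key for batch ids like 5, "5/eval" or "123/my_task"."""
    # Assumption: the leading integer dominates, and a plain int id sorts
    # before string ids that share the same numeric prefix.
    if isinstance(batch_id, int):
        return (batch_id, 0, "")
    match = re.match(r"\d+", batch_id)
    if match is None:
        raise ValueError(f"batch_id must start with an integer: {batch_id!r}")
    return (int(match.group()), 1, batch_id)


ids = ["3/eval", 2, "2/eval", 10, "10/my_task"]
print(sorted(ids, key=sort_batch_id_sketch))
# [2, '2/eval', '3/eval', 10, '10/my_task']
```

Plain string sorting would place "10/my_task" before "2/eval"; extracting the integer prefix avoids that.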

class trinity.explorer.scheduler.Scheduler(config: Config, rollout_model: List[InferenceModel], auxiliary_models: List[List[InferenceModel]] | None = None) [source]

Bases: object

Scheduler for rollout tasks.

Supports scheduling tasks to multiple runners, retrying failed tasks, and collecting results at different levels.

__init__(config: Config, rollout_model: List[InferenceModel], auxiliary_models: List[List[InferenceModel]] | None = None) [source]
task_done_callback(async_task: Task) [source]
async start() → None [source]
async stop() → None [source]
schedule(tasks: List[Task], batch_id: int | str) → None [source]

Schedule the provided tasks.

Parameters:
  • tasks (List[Task]) -- The tasks to schedule.

  • batch_id (Union[int, str]) -- The id of the provided tasks. It should be an integer or a string starting with an integer (e.g., 123, "123/my_task").

dynamic_timeout(timeout: float | None = None) → float [source]

Calculate a dynamic timeout based on historical data.
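The documentation does not give the formula that dynamic_timeout uses. A plausible sketch, assuming the timeout is a multiple of the mean duration of recent runs, capped by the caller's timeout (or a default), and falling back to that bound when no history exists; DynamicTimeoutSketch, multiplier and window are hypothetical names.

```python
from collections import deque
from statistics import mean
from typing import Deque, Optional


class DynamicTimeoutSketch:
    """Hypothetical: timeout = multiplier * mean of recent run durations."""

    def __init__(self, default_timeout: float, multiplier: float = 2.0, window: int = 50) -> None:
        self.default_timeout = default_timeout
        self.multiplier = multiplier
        # Only the most recent `window` durations are kept.
        self.history: Deque[float] = deque(maxlen=window)

    def record(self, elapsed: float) -> None:
        self.history.append(elapsed)

    def dynamic_timeout(self, timeout: Optional[float] = None) -> float:
        upper = timeout if timeout is not None else self.default_timeout
        if not self.history:
            return upper  # no data yet: fall back to the static bound
        # Never exceed the caller-provided (or default) upper bound.
        return min(upper, self.multiplier * mean(self.history))


t = DynamicTimeoutSketch(default_timeout=600.0)
print(t.dynamic_timeout())  # 600.0
for elapsed in (10.0, 20.0, 30.0):
    t.record(elapsed)
print(t.dynamic_timeout())  # 40.0
print(t.dynamic_timeout(timeout=25.0))  # 25.0
```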

async get_results(batch_id: int | str, min_num: int | None = None, timeout: float | None = None, clear_timeout_tasks: bool = True) → Tuple[List[Status], List[Experience]] [source]

Get the results of tasks at the specified batch_id.

Parameters:
  • batch_id (Union[int, str]) -- Only wait for tasks at this batch.

  • min_num (int) -- The minimum number of tasks to wait for. If None, wait for all tasks at batch_id.

  • timeout (float) -- The timeout for waiting for tasks to finish. If None, use the default timeout.

  • clear_timeout_tasks (bool) -- Whether to clear timed-out tasks.
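Waiting for at least min_num tasks under an overall deadline maps naturally onto asyncio.wait with FIRST_COMPLETED. In this hypothetical sketch (wait_for_min_sketch is an invented name), clear_timeout_tasks is interpreted as cancelling whatever is still pending when the deadline passes; that interpretation is an assumption.

```python
import asyncio
from typing import Any, List, Optional, Set


async def wait_for_min_sketch(
    tasks: Set[asyncio.Task],
    min_num: Optional[int] = None,
    timeout: Optional[float] = None,
    clear_timeout_tasks: bool = True,
) -> List[Any]:
    """Hypothetical: collect results until min_num tasks finish or the deadline passes."""
    target = len(tasks) if min_num is None else min(min_num, len(tasks))
    loop = asyncio.get_running_loop()
    deadline = None if timeout is None else loop.time() + timeout
    results: List[Any] = []
    pending = set(tasks)
    while len(results) < target and pending:
        remaining = None if deadline is None else max(0.0, deadline - loop.time())
        done, pending = await asyncio.wait(
            pending, timeout=remaining, return_when=asyncio.FIRST_COMPLETED
        )
        if not done:
            break  # deadline reached with nothing new finished
        results.extend(t.result() for t in done)  # re-raises a task's exception
    if clear_timeout_tasks and len(results) < target:
        # Assumption: "clear" means cancel the stragglers.
        for t in pending:
            t.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
    return results


async def demo() -> List[int]:
    async def job(value: int, delay: float) -> int:
        await asyncio.sleep(delay)
        return value

    tasks = {
        asyncio.create_task(job(1, 0.0)),
        asyncio.create_task(job(2, 0.0)),
        asyncio.create_task(job(3, 10.0)),  # will not finish before the deadline
    }
    return sorted(await wait_for_min_sketch(tasks, timeout=0.2))


print(asyncio.run(demo()))  # [1, 2]
```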

has_step(batch_id: int | str) → bool [source]
async wait_all(timeout: float | None = None, clear_timeout_tasks: bool = True) → None [source]

Wait for all tasks to complete without popping their results. If the timeout is reached, raise TimeoutError.

Parameters:
  • timeout (float) -- Timeout in seconds. TimeoutError is raised when no new task completes within the timeout.

  • clear_timeout_tasks (bool) -- Whether to clear timed-out tasks.
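As documented, the wait_all timeout is a progress timeout: it fires when no new task completes within the window, not when the total wait exceeds it. A hypothetical sketch of that rule with asyncio.wait (wait_all_sketch is an invented name); clear_timeout_tasks is left out, and results stay on the task objects rather than being popped.

```python
import asyncio
from typing import Optional, Set


async def wait_all_sketch(tasks: Set[asyncio.Task], timeout: Optional[float] = None) -> None:
    """Hypothetical: wait for every task; raise TimeoutError if no task
    completes within `timeout` seconds of the previous completion."""
    pending = set(tasks)
    while pending:
        # Each completion restarts the timeout window.
        done, pending = await asyncio.wait(
            pending, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
        )
        if not done:
            raise TimeoutError(f"no task completed within {timeout} s; {len(pending)} pending")


async def demo() -> str:
    quick = {asyncio.create_task(asyncio.sleep(0.01 * i)) for i in range(3)}
    await wait_all_sketch(quick, timeout=1.0)
    stuck = asyncio.create_task(asyncio.sleep(10))
    try:
        await wait_all_sketch({stuck}, timeout=0.05)
    except TimeoutError:
        stuck.cancel()
        return "timed out"
    return "finished"


print(asyncio.run(demo()))  # timed out
```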

get_key_state(key: str) → Dict [source]

Get the state of every runner for a given key.

Parameters:

key (str) -- The key of the state to get.

Returns:

A dictionary mapping runner ids to their state for the given key.

Return type:

Dict
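Given per-runner state dictionaries, the per-key view that get_key_state returns amounts to a projection. A hypothetical sketch: the function name and the state keys status and task_id are invented for illustration, and mapping a missing key to None is an assumption.

```python
from typing import Any, Dict


def get_key_state_sketch(all_states: Dict[int, Dict[str, Any]], key: str) -> Dict[int, Any]:
    # Hypothetical: project each runner's state dict onto one key; runners
    # that have not reported the key map to None (an assumption).
    return {runner_id: state.get(key) for runner_id, state in all_states.items()}


states = {0: {"status": "running", "task_id": 7}, 1: {"status": "idle"}}
print(get_key_state_sketch(states, "status"))  # {0: 'running', 1: 'idle'}
print(get_key_state_sketch(states, "task_id"))  # {0: 7, 1: None}
```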

get_runner_state(runner_id: int) → Dict [source]

Get the state of a single runner.

Parameters:

runner_id (int) -- The id of the runner.

Returns:

The state of the runner.

Return type:

Dict

get_all_state() → Dict [source]

Get the state of all runners.

Returns:

The state of all runners.

Return type:

Dict

print_all_state() → None [source]

Print the state of all runners as a clear, aligned table.
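An aligned table of runner states can be produced by padding each column to its widest cell. A hypothetical sketch (format_state_table and the sample state keys are invented); it returns the table as a string so the caller decides where to print it.

```python
from typing import Any, Dict


def format_state_table(all_states: Dict[int, Dict[str, Any]]) -> str:
    """Hypothetical: render runner states as an aligned text table."""
    columns = sorted({k for state in all_states.values() for k in state})
    header = ["runner_id"] + columns
    # Missing values are shown as "-".
    rows = [
        [str(rid)] + [str(state.get(c, "-")) for c in columns]
        for rid, state in sorted(all_states.items())
    ]
    # Each column is as wide as its widest cell, header included.
    widths = [max(len(row[i]) for row in [header] + rows) for i in range(len(header))]
    lines = ["  ".join(cell.ljust(w) for cell, w in zip(row, widths)).rstrip() for row in [header] + rows]
    lines.insert(1, "  ".join("-" * w for w in widths))  # separator under the header
    return "\n".join(lines)


table = format_state_table({0: {"status": "running", "task_id": 7}, 1: {"status": "idle"}})
print(table)
```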