trinity.explorer

Submodules

trinity.explorer.explorer module

The explorer module

class trinity.explorer.explorer.Explorer(config: Config)[source]

Bases: object

Responsible for exploring the taskset.

__init__(config: Config)[source]
async setup_weight_sync_group(master_address: str, master_port: int, state_dict_meta: List | None = None)[source]
async prepare() None[source]

Preparation before running.

async get_weight(name: str) Tensor[source]

Get the weight of the loaded model (for checkpoint weight updates).

async explore() str[source]
The timeline of the exploration process:

         | <--------------------------------- one period ---------------------------------> |
explorer | <---------- step_1 ---------->                                                   |
         | <---------- step_2 ---------->                                                   |
         | ...                                                                              |
         | <---------- step_n ---------->                                                   |
         | <------------------- eval ------------------->                      <-- sync --> |
         |----------------------------------------------------------------------------------|
trainer  | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- sync --> |
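The period above can be sketched as a loop over the coroutines listed below. Everything in this sketch is a hypothetical stand-in (the counters, the logging list, and the interval logic are assumptions), not the real trinity.explorer implementation:

```python
# Hypothetical sketch of one explore() period: run steps until the period
# ends, then evaluate if due, then synchronize weights with the trainer.
# ExplorerSketch is a stand-in, not the real trinity.explorer.Explorer.
import asyncio


class ExplorerSketch:
    def __init__(self, steps_per_period: int, eval_interval: int):
        self.steps_per_period = steps_per_period
        self.eval_interval = eval_interval
        self.step = 0
        self.log = []

    async def explore_step(self) -> bool:
        # Run one rollout step; return False when the period is over.
        self.step += 1
        self.log.append(f"step_{self.step}")
        return self.step < self.steps_per_period

    async def need_sync(self) -> bool:
        # Assumed policy: sync at the end of every period.
        return self.step % self.steps_per_period == 0

    def need_eval(self) -> bool:
        # Assumed policy: evaluate every eval_interval steps.
        return self.step % self.eval_interval == 0

    async def eval(self):
        self.log.append("eval")

    async def sync_weight(self):
        self.log.append("sync")

    async def explore(self) -> str:
        # Mirror the timeline: steps, then eval, then sync.
        while await self.explore_step():
            pass
        if self.need_eval():
            await self.eval()
        if await self.need_sync():
            await self.sync_weight()
        return "finished"
```

Running `asyncio.run(ExplorerSketch(steps_per_period=3, eval_interval=3).explore())` walks through `step_1 .. step_3`, one eval, and one sync, matching the diagram's ordering.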

async explore_step() bool[source]
async need_sync() bool[source]
need_eval() bool[source]
async eval()[source]

Evaluation on all evaluation data samples.

async benchmark() bool[source]

Benchmark the model checkpoints.

async save_checkpoint(sync_weight: bool = False) None[source]
async sync_weight() None[source]

Synchronize model weights.

async shutdown() None[source]
is_alive() bool[source]

Check if the explorer is alive.

trinity.explorer.scheduler module

Scheduler for rollout tasks.

class trinity.explorer.scheduler.TaskWrapper(task: Task, batch_id: int | str, run_id_base: int = 0, repeat_times: int = 1)[source]

Bases: object

A wrapper for a task.

task: Task
batch_id: int | str
run_id_base: int = 0
repeat_times: int = 1
__init__(task: Task, batch_id: int | str, run_id_base: int = 0, repeat_times: int = 1) None
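The wrapper's fields can be pictured with a small dataclass. The run-id expansion shown here (one id per repetition, offset from run_id_base) is an assumed scheme for illustration, and a plain string stands in for trinity's Task type:

```python
# Hypothetical model of TaskWrapper's fields; not the real trinity class.
from dataclasses import dataclass
from typing import List, Union


@dataclass
class TaskWrapperSketch:
    task: str                      # stand-in for trinity's Task
    batch_id: Union[int, str]
    run_id_base: int = 0
    repeat_times: int = 1

    def run_ids(self) -> List[int]:
        # Assumed scheme: one run id per repetition, offset from run_id_base.
        return [self.run_id_base + i for i in range(self.repeat_times)]
```

For example, a task repeated three times starting at base 4 would occupy run ids 4, 5, and 6 under this scheme.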
class trinity.explorer.scheduler.RunnerWrapper(runner_id: int, rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config)[source]

Bases: object

A wrapper for a WorkflowRunner.

__init__(runner_id: int, rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config)[source]
async run_with_retry(task: TaskWrapper) Tuple[Status, List, int][source]
Returns:

  • Status – The return status of the task.

  • List – The experiences generated by the task.

  • int – The runner_id of the current runner.

Return type:

Tuple[Status, List, int]

restart_runner()[source]
trinity.explorer.scheduler.sort_batch_id(batch_id: int | str)[source]

Return the scheduling priority of a batch_id.
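Since a batch_id is either an integer or a string with a leading integer (e.g., "123/my_task", as documented under schedule() below), a plausible priority key simply extracts that leading integer. This re-implementation is an assumption about the ordering, not trinity's exact function:

```python
# Plausible sketch of a batch_id priority key: the leading integer of the
# batch_id determines scheduling order. Tie-breaking is not modeled here.
from typing import Union


def sort_batch_id_sketch(batch_id: Union[int, str]) -> int:
    if isinstance(batch_id, int):
        return batch_id
    # For strings like "123/my_task", take the part before the first "/".
    head = batch_id.split("/", 1)[0]
    return int(head)
```

Under this key, batch "7" sorts before "123/my_task" regardless of the task-name suffix.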

class trinity.explorer.scheduler.Scheduler(config: Config, rollout_model: List[InferenceModel], auxiliary_models: List[List[InferenceModel]] | None = None)[source]

Bases: object

Scheduler for rollout tasks.

__init__(config: Config, rollout_model: List[InferenceModel], auxiliary_models: List[List[InferenceModel]] | None = None)[source]
task_done_callback(async_task: Task)[source]
async start() None[source]
async stop() None[source]
schedule(tasks: List[Task], batch_id: int | str) None[source]

Schedule the provided tasks.

Parameters:
  • tasks (List[Task]) – The tasks to schedule.

  • batch_id (Union[int, str]) – The ID of the provided tasks. It should be an integer or a string starting with an integer (e.g., 123, "123/my_task")

async get_results(batch_id: int | str, min_num: int | None = None, timeout: float | None = None, clear_timeout_tasks: bool = True) Tuple[List[Status], List[Experience]][source]

Get the results of tasks for the specified batch_id.

Parameters:
  • batch_id (Union[int, str]) – Only wait for tasks at this batch.

  • min_num (int) – The minimum number of tasks to wait for. If None, wait for all tasks at batch_id.

  • timeout (float) – The timeout for waiting for tasks to finish. If None, wait for the default timeout.

  • clear_timeout_tasks (bool) – Whether to clear timeout tasks.
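The schedule()/get_results() contract can be illustrated with an in-memory toy. Everything below is a simplified stand-in (tasks are strings, "experiences" are strings, tasks run synchronously instead of being dispatched to runners), so it shows only the call pattern, not trinity's dispatch logic:

```python
# Toy illustration of the Scheduler contract: schedule() enqueues tasks
# under a batch_id; get_results() collects (statuses, experiences) for it.
# All types here are simplified stand-ins for the trinity classes.
import asyncio
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union


@dataclass
class Status:
    ok: bool
    metric: dict


class ToyScheduler:
    def __init__(self):
        self._done: Dict[Union[int, str], List[Tuple[Status, str]]] = {}

    def schedule(self, tasks: List[str], batch_id: Union[int, str]) -> None:
        # The toy runs each task immediately; the real scheduler dispatches
        # them to WorkflowRunner actors.
        results = [
            (Status(ok=True, metric={"len": float(len(t))}), f"exp:{t}")
            for t in tasks
        ]
        self._done.setdefault(batch_id, []).extend(results)

    async def get_results(self, batch_id, min_num=None, timeout=None):
        # Pop results for this batch; min_num/timeout are ignored in the toy.
        pairs = self._done.pop(batch_id, [])
        statuses = [s for s, _ in pairs]
        experiences = [e for _, e in pairs]
        return statuses, experiences


scheduler = ToyScheduler()
scheduler.schedule(["a", "bb"], "0/train")
statuses, experiences = asyncio.run(scheduler.get_results("0/train"))
```

Note that get_results() consumes the batch: a second call with the same batch_id returns empty lists, which matches the "pop results" behavior described for wait_all() below.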

has_step(batch_id: int | str) bool[source]
async wait_all(timeout: float | None = None, clear_timeout_tasks: bool = True) None[source]

Wait for all tasks to complete without popping results. Raises TimeoutError if the timeout is reached.

Parameters:
  • timeout (float) – Timeout in seconds. Raise TimeoutError when no new task completes within the timeout.

  • clear_timeout_tasks (bool) – Whether to clear timeout tasks.

trinity.explorer.workflow_runner module

The Workflow Runner Module.

class trinity.explorer.workflow_runner.Status(ok: bool, metric: dict[str, float], message: str | None = None)[source]

Bases: object

Status of the task running result.

ok: bool
metric: dict[str, float]
message: str | None = None
__init__(ok: bool, metric: dict[str, float], message: str | None = None) None
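The three fields above map directly onto a dataclass; the example values (a retry count and a failure message) are illustrative, not taken from trinity:

```python
# Minimal model of the Status result type, mirroring the fields documented
# above: a success flag, a metric dict, and an optional message.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Status:
    ok: bool
    metric: dict
    message: Optional[str] = None


# A failed run might carry diagnostics in metric and message (hypothetical values):
failed = Status(ok=False, metric={"retries": 2.0}, message="workflow raised TimeoutError")
ok = Status(ok=True, metric={"reward": 1.0})
```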
class trinity.explorer.workflow_runner.WorkflowRunner(config: Config, model: InferenceModel, auxiliary_models: List[InferenceModel] | None = None, runner_id: int | None = None)[source]

Bases: object

A Ray remote actor to run the workflow and generate experiences.

__init__(config: Config, model: InferenceModel, auxiliary_models: List[InferenceModel] | None = None, runner_id: int | None = None) None[source]
is_alive()[source]
run_task(task: Task, repeat_times: int = 1, run_id_base: int = 0) Tuple[Status, List[Experience]][source]

Run the task and return its status and the generated experiences.
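Given the repeat_times and run_id_base fields documented for TaskWrapper above, a plausible shape for run_task's contract is: execute the workflow repeat_times times and return a status plus one experience per run, with run ids offset by run_id_base. These semantics are an assumption for illustration:

```python
# Hypothetical sketch of run_task's contract; dicts stand in for trinity's
# Status and Experience types, and the run-id scheme is assumed.
from typing import Dict, List, Tuple


def run_task_sketch(
    task: str, repeat_times: int = 1, run_id_base: int = 0
) -> Tuple[Dict, List[Dict]]:
    # One experience per repetition, tagged with its assumed run id.
    experiences = [
        {"run_id": run_id_base + i, "task": task} for i in range(repeat_times)
    ]
    status = {"ok": True, "metric": {"runs": float(repeat_times)}}
    return status, experiences
```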

Module contents

class trinity.explorer.Explorer(config: Config)[source]

Bases: object

Responsible for exploring the taskset.

__init__(config: Config)[source]
async setup_weight_sync_group(master_address: str, master_port: int, state_dict_meta: List | None = None)[source]
async prepare() None[source]

Preparation before running.

async get_weight(name: str) Tensor[source]

Get the weight of the loaded model (for checkpoint weight updates).

async explore() str[source]
The timeline of the exploration process:

         | <--------------------------------- one period ---------------------------------> |
explorer | <---------- step_1 ---------->                                                   |
         | <---------- step_2 ---------->                                                   |
         | ...                                                                              |
         | <---------- step_n ---------->                                                   |
         | <------------------- eval ------------------->                      <-- sync --> |
         |----------------------------------------------------------------------------------|
trainer  | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- sync --> |

async explore_step() bool[source]
async need_sync() bool[source]
need_eval() bool[source]
async eval()[source]

Evaluation on all evaluation data samples.

async benchmark() bool[source]

Benchmark the model checkpoints.

async save_checkpoint(sync_weight: bool = False) None[source]
async sync_weight() None[source]

Synchronize model weights.

async shutdown() None[source]
is_alive() bool[source]

Check if the explorer is alive.