trinity.manager package#

Subpackages#

Submodules#

Module contents#

class trinity.manager.StateManager(path: str, trainer_name: str | None = None, explorer_name: str | None = None, config: Config | None = None, check_config: bool = False)[源代码]#

基类:object

A Manager class for managing the running state of Explorer and Trainer.

__init__(path: str, trainer_name: str | None = None, explorer_name: str | None = None, config: Config | None = None, check_config: bool = False)[源代码]#
load_explorer() → dict[源代码]#
load_explorer_server_url() → str | None[源代码]#
load_stage() → dict[源代码]#
load_trainer() → dict[源代码]#
save_explorer(current_step: int, taskset_states: List[Dict]) → None[源代码]#
save_explorer_server_url(url: str) → None[源代码]#
save_stage(current_stage: int) → None[源代码]#
save_trainer(current_step: int, sample_strategy_state: dict) → None[源代码]#
class trinity.manager.Synchronizer(config: Config, module_ref: ActorHandle)[源代码]#

基类:object

A central component to manage synchronization of models and states between the trainer and one or more explorers in a distributed training setup.

trainer_status#

Current status of the trainer (e.g., running, waiting).

explorer_status_counts#

Dictionary tracking the number of explorers in each status.

_ready_condition#

Async condition variable for signaling state changes.

model_state_dict#

The latest model weights.

model_version#

Version number of the current model.

checkpoint_shard_counter#

Tracks how many shards have been received from the trainer for a specific training step.

__init__(config: Config, module_ref: ActorHandle)[源代码]#
async add_module(module_ref: ActorHandle) → None[源代码]#

Adds a module to be tracked by the synchronizer.

参数:

module_ref -- The Ray actor handle of the module to track.

classmethod get_actor(config: Config | None = None, namespace: str | None = None)[源代码]#

Get or create a remote Ray actor for the Synchronizer.

参数:
  • config -- Optional configuration to use for creating the actor.

  • namespace -- Optional Ray namespace for the actor.

返回:

A reference to the Synchronizer actor.

get_explorer_status_counts() → Dict[RunningStatus, int][源代码]#

Return the current status counts for all explorers.

async get_latest_model_version() → int[源代码]#

Get the latest model version available in the synchronizer.

返回:

The current model version.

get_model_state_dict()[源代码]#

Return the current model state and its version.

async get_state_dict_meta()[源代码]#

Return metadata about the model state (names, data types, shapes).

返回:

A list of (name, dtype, shape) tuples describing the model state.

返回类型:

List of tuples

get_trainer_status() → RunningStatus[源代码]#

Get the current status of the trainer.

async ready_to_nccl_sync(module: str, trainer_step: int | None = None) → int | None[源代码]#

Prepare for NCCL-based synchronization between modules.

Currently supports only one explorer.

参数:
  • module -- Either 'trainer' or 'explorer'.

  • trainer_step -- Optional step number from the trainer.

返回:

The model version if both sides are ready; otherwise None.

async set_explorer_status(status: RunningStatus, old_status: RunningStatus | None = None)[源代码]#

Update the status count for an explorer.

参数:
  • status -- New status of the explorer.

  • old_status -- Previous status if changing from one to another.

async set_model_state_dict(model_state_dict: dict | None | str | Tuple[str, str], trainer_step: int)[源代码]#

Set the new model state and update the version.

参数:
  • model_state_dict -- The PyTorch model state dictionary.

  • trainer_step -- Step number associated with this model version.

async set_model_state_dict_with_step_num(step_num: int | None = None, world_size: int | None = None) → int[源代码]#

Load and set the model state dictionary from a checkpoint at a specific step.

参数:
  • step_num -- Training step number corresponding to the checkpoint.

  • world_size -- Number of shards expected for this checkpoint.

返回:

The updated model version (step number).

async set_trainer_status(status: RunningStatus)[源代码]#

Update the status of the trainer.

async setup_weight_sync_group(master_address: str, master_port: int, state_dict_meta: List = None)[源代码]#

Notify the explorer actor to set up the weight sync group.

This is used to initialize NCCL-based synchronization for distributed training.

参数:
  • master_address -- IP address of the master node.

  • master_port -- Port used for synchronization.

  • state_dict_meta -- Metadata of the model parameters.

async wait_new_model_state_dict(current_version: int, no_wait: bool = False) → int[源代码]#

Wait until a new model state is available.

参数:

current_version -- Current model version known to one explorer.

返回:

The new model version after it has been updated.