trinity.manager.synchronizer module
A centralized synchronizer for coordinating explorer and trainer.
- class trinity.manager.synchronizer.Synchronizer(config: Config, module_ref: ActorHandle)
Bases: object
A central component that manages synchronization of models and states between the trainer and one or more explorers in a distributed training setup.
- trainer_status
Current status of the trainer (e.g., running, waiting).
- explorer_status_counts
Dictionary tracking the number of explorers in each status.
- _ready_condition
Async condition variable for signaling state changes.
- model_state_dict
The latest model weights.
- model_version
Version number of the current model.
- checkpoint_shard_counter
Tracks how many shards have been received from the trainer for a given training step.
- async add_module(module_ref: ActorHandle) -> None
Add a module to be tracked by the synchronizer.
- Parameters:
module_ref -- The Ray actor handle of the module to track.
- async set_trainer_status(status: RunningStatus)
Update the status of the trainer.
- get_trainer_status() -> RunningStatus
Get the current status of the trainer.
- async set_explorer_status(status: RunningStatus, old_status: RunningStatus | None = None)
Update the status count for an explorer.
- Parameters:
status -- New status of the explorer.
old_status -- Previous status of the explorer, if it is changing from one status to another.
- get_explorer_status_counts() -> Dict[RunningStatus, int]
Return the current status counts for all explorers.
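The status bookkeeping described above can be pictured with a small stand-alone sketch. It is not the Synchronizer's actual code: the `RunningStatus` members, the `StatusCounter` class, and the rule of decrementing the old status before incrementing the new one are all assumptions made for illustration.

```python
from collections import defaultdict
from enum import Enum
from typing import Dict, Optional


class RunningStatus(Enum):
    # Hypothetical stand-in; the real enum's members may differ.
    RUNNING = "running"
    WAITING_SYNC = "waiting_sync"
    STOPPED = "stopped"


class StatusCounter:
    """Toy bookkeeping for how many explorers are in each status."""

    def __init__(self) -> None:
        self.counts: Dict[RunningStatus, int] = defaultdict(int)

    def set_status(
        self, status: RunningStatus, old_status: Optional[RunningStatus] = None
    ) -> None:
        # When an explorer moves between statuses, release its old slot first.
        if old_status is not None and self.counts[old_status] > 0:
            self.counts[old_status] -= 1
        self.counts[status] += 1

    def get_counts(self) -> Dict[RunningStatus, int]:
        # Return a plain-dict copy so callers cannot mutate internal state.
        return dict(self.counts)


counter = StatusCounter()
counter.set_status(RunningStatus.RUNNING)  # first explorer registers
counter.set_status(RunningStatus.RUNNING)  # second explorer registers
# One explorer leaves RUNNING and starts waiting for a sync.
counter.set_status(RunningStatus.WAITING_SYNC, old_status=RunningStatus.RUNNING)
print(counter.get_counts())
```

Passing `old_status` keeps the totals consistent: without it, an explorer that changes status would be counted twice.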
- async set_model_state_dict_with_step_num(step_num: int | None = None, world_size: int | None = None) -> int
Load and set the model state dictionary from a checkpoint at a specific step.
- Parameters:
step_num -- Training step number corresponding to the checkpoint.
world_size -- Number of shards expected for this checkpoint.
- Returns:
The updated model version (step number).
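One way to read the interplay between `world_size` and `checkpoint_shard_counter` is that the version advances only once every shard for a step has been reported. The sketch below illustrates that reading only. `ShardTracker` and `report_shard` are hypothetical names, and the real method may load and validate checkpoints differently.

```python
from collections import defaultdict
from typing import Dict


class ShardTracker:
    """Toy model of counting checkpoint shards per training step."""

    def __init__(self) -> None:
        self.model_version = 0
        self.shard_counter: Dict[int, int] = defaultdict(int)

    def report_shard(self, step_num: int, world_size: int) -> int:
        # Each trainer rank reports once its shard for `step_num` is saved.
        self.shard_counter[step_num] += 1
        if self.shard_counter[step_num] == world_size:
            # All shards are present. A real implementation would load the
            # checkpoint from storage here before bumping the version.
            self.model_version = step_num
            del self.shard_counter[step_num]
        return self.model_version


tracker = ShardTracker()
first = tracker.report_shard(step_num=10, world_size=2)   # one of two shards
second = tracker.report_shard(step_num=10, world_size=2)  # checkpoint complete
print(first, second)
```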
- async set_model_state_dict(model_state_dict: dict | None | str | Tuple[str, str], trainer_step: int)
Set the new model state and update the version.
- Parameters:
model_state_dict -- The PyTorch model state dictionary.
trainer_step -- Step number associated with this model version.
- async get_state_dict_meta()
Return metadata about the model state (names, data types, shapes).
- Returns:
(name, dtype, shape) for each entry in the model state.
- Return type:
List of tuples
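For illustration, metadata of this shape could be derived from a state dictionary roughly as follows. This is a sketch under assumptions: `FakeTensor` is a hypothetical stand-in for a `torch.Tensor` so the example runs without PyTorch, and whether the real method reports dtypes as strings is not stated in this reference.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass
class FakeTensor:
    # Hypothetical stand-in for a torch.Tensor; only dtype and shape matter.
    dtype: str
    shape: Tuple[int, ...]


def state_dict_meta(
    state_dict: Dict[str, Any],
) -> List[Tuple[str, str, Tuple[int, ...]]]:
    """Collect (name, dtype, shape) for every entry, in insertion order."""
    return [
        (name, str(tensor.dtype), tuple(tensor.shape))
        for name, tensor in state_dict.items()
    ]


meta = state_dict_meta(
    {
        "linear.weight": FakeTensor("float32", (4, 8)),
        "linear.bias": FakeTensor("float32", (4,)),
    }
)
print(meta)
```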
- async setup_weight_sync_group(master_address: str, master_port: int, state_dict_meta: List = None)
Notify the explorer actor to set up the weight sync group.
This is used to initialize NCCL-based synchronization for distributed training.
- Parameters:
master_address -- IP address of the master node.
master_port -- Port used for synchronization.
state_dict_meta -- Metadata of the model parameters.
- async wait_new_model_state_dict(current_version: int, no_wait: bool = False) -> int
Wait until a new model state is available.
- Parameters:
current_version -- Current model version known to one explorer.
- Returns:
The new model version after it has been updated.
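The waiting behaviour can be sketched with an `asyncio.Condition`, in the spirit of the `_ready_condition` attribute listed above. This is a minimal sketch, not the library's code: `VersionBoard` is a hypothetical class, and the meaning given to `no_wait` here (return the current version immediately) is an assumption, since this reference does not document that parameter.

```python
import asyncio


class VersionBoard:
    """Toy version tracker built on an asyncio.Condition."""

    def __init__(self) -> None:
        self.model_version = 0
        self._ready_condition = asyncio.Condition()

    async def publish(self, version: int) -> None:
        async with self._ready_condition:
            self.model_version = version
            self._ready_condition.notify_all()  # wake every waiting explorer

    async def wait_new(self, current_version: int, no_wait: bool = False) -> int:
        async with self._ready_condition:
            if not no_wait:
                # Block until a version newer than the caller's appears.
                await self._ready_condition.wait_for(
                    lambda: self.model_version > current_version
                )
            return self.model_version


async def demo() -> tuple:
    board = VersionBoard()
    # Assumed no_wait semantics: report the current version without blocking.
    stale = await board.wait_new(current_version=0, no_wait=True)
    waiter = asyncio.create_task(board.wait_new(current_version=0))
    await asyncio.sleep(0)  # give the waiter a chance to start waiting
    await board.publish(3)
    return stale, await waiter


result = asyncio.run(demo())
print(result)
```

Using `wait_for` with a predicate guards against spurious wake-ups: a waiter returns only when the version it sees is actually newer than the one it already holds.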
- async get_latest_model_version() -> int
Get the latest model version available in the synchronizer.
- Returns:
The current model version.
- async ready_to_nccl_sync(module: str, trainer_step: int | None = None) -> int | None
Prepare for NCCL-based synchronization between modules.
Currently supports only one explorer.
- Parameters:
module -- Either 'trainer' or 'explorer'.
trainer_step -- Optional step number from the trainer.
- Returns:
The model version if both sides are ready; otherwise None.
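The "both sides ready, otherwise None" contract resembles a two-party handshake. The sketch below shows one possible shape of such a handshake. It is not the Synchronizer's implementation: the `Rendezvous` class, its `timeout` setting, and the choice to return None on timeout are all assumptions made for illustration.

```python
import asyncio
from typing import Dict, Optional


class Rendezvous:
    """Toy one-shot handshake between a 'trainer' and an 'explorer'."""

    def __init__(self, timeout: float = 1.0) -> None:
        self.model_version = 0
        self.timeout = timeout  # hypothetical setting, not from the docs
        self._ready: Dict[str, bool] = {"trainer": False, "explorer": False}
        self._cond = asyncio.Condition()

    async def ready(
        self, module: str, trainer_step: Optional[int] = None
    ) -> Optional[int]:
        other = "explorer" if module == "trainer" else "trainer"
        async with self._cond:
            if trainer_step is not None:
                self.model_version = trainer_step
            self._ready[module] = True
            self._cond.notify_all()  # tell the other side we have arrived
            try:
                # Wait (bounded) until the other side has arrived as well.
                await asyncio.wait_for(
                    self._cond.wait_for(lambda: self._ready[other]),
                    self.timeout,
                )
            except asyncio.TimeoutError:
                self._ready[module] = False  # withdraw our readiness
                return None
            return self.model_version


async def demo():
    # Only the explorer shows up: the handshake times out and yields None.
    lonely = await Rendezvous(timeout=0.05).ready("explorer")
    # Both sides show up: each receives the trainer's step as the version.
    pair = Rendezvous()
    paired = await asyncio.gather(
        pair.ready("trainer", trainer_step=5), pair.ready("explorer")
    )
    return lonely, paired


lonely, paired = asyncio.run(demo())
print(lonely, paired)
```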
- classmethod get_actor(config: Config | None = None, namespace: str | None = None)
Get or create a remote Ray actor for the Synchronizer.
- Parameters:
config -- Optional configuration to use for creating the actor.
namespace -- Optional Ray namespace for the actor.
- Returns:
A reference to the Synchronizer actor.