trinity.manager package
Submodules
Module contents
- class trinity.manager.CacheManager(config: Config, check_config: bool = False)
Bases: object
A manager class for managing the cache directory.
- class trinity.manager.Synchronizer(config: Config, module_ref: ActorHandle)
Bases: object
A central component to manage synchronization of models and states between the trainer and one or more explorers in a distributed training setup.
- trainer_status
Current status of the trainer (e.g., running, waiting).
- explorer_status_counts
Dictionary tracking the number of explorers in each status.
- _ready_condition
Async condition variable for signaling state changes.
- model_state_dict
The latest model weights.
- model_version
Version number of the current model.
- checkpoint_shard_counter
Tracks how many checkpoint shards have been received from the trainer for a given training step.
- async add_module(module_ref: ActorHandle) -> None
Adds a module to be tracked by the synchronizer.
- Parameters:
module_ref – The Ray actor handle of the module to track.
- async set_trainer_status(status: RunningStatus)
Update the status of the trainer.
- get_trainer_status() -> RunningStatus
Get the current status of the trainer.
- async set_explorer_status(status: RunningStatus, old_status: RunningStatus | None = None)
Update the status count for an explorer.
- Parameters:
status – New status of the explorer.
old_status – Previous status of the explorer, if it is transitioning from one status to another.
- get_explorer_status_counts() -> Dict[RunningStatus, int]
Return the current status counts for all explorers.
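The status bookkeeping above can be pictured with a small asyncio sketch. It assumes that a transition decrements the count for `old_status` and increments the count for `status` in `explorer_status_counts`; the `RunningStatus` members and the `ExplorerStatusTracker` class are hypothetical stand-ins, not the library's definitions.

```python
import asyncio
from collections import defaultdict
from enum import Enum
from typing import Dict, Optional


class RunningStatus(Enum):
    # Hypothetical members; the real RunningStatus enum may define others.
    RUNNING = "running"
    WAITING_SYNC = "waiting_sync"


class ExplorerStatusTracker:
    """Hypothetical sketch: count how many explorers are in each status."""

    def __init__(self) -> None:
        self.explorer_status_counts: Dict[RunningStatus, int] = defaultdict(int)

    async def set_explorer_status(
        self, status: RunningStatus, old_status: Optional[RunningStatus] = None
    ) -> None:
        # No await between the two updates, so other coroutines on the same
        # event loop never observe a half-applied transition.
        if old_status is not None:
            self.explorer_status_counts[old_status] -= 1
        self.explorer_status_counts[status] += 1

    def get_explorer_status_counts(self) -> Dict[RunningStatus, int]:
        # Return a plain copy so callers cannot mutate the internal counts.
        return dict(self.explorer_status_counts)


async def main() -> Dict[RunningStatus, int]:
    tracker = ExplorerStatusTracker()
    await tracker.set_explorer_status(RunningStatus.RUNNING)  # explorer A starts
    await tracker.set_explorer_status(RunningStatus.RUNNING)  # explorer B starts
    # Explorer A moves from RUNNING to WAITING_SYNC.
    await tracker.set_explorer_status(RunningStatus.WAITING_SYNC, old_status=RunningStatus.RUNNING)
    return tracker.get_explorer_status_counts()


counts = asyncio.run(main())
```

Because `old_status` is optional, a newly started explorer can register with `status` alone, while a transition passes both values so the totals stay consistent.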
- async set_model_state_dict_with_step_num(step_num: int | None = None, world_size: int | None = None) -> int
Load and set the model state dictionary from a checkpoint at a specific step.
- Parameters:
step_num – Training step number corresponding to the checkpoint.
world_size – Number of shards expected for this checkpoint.
- Returns:
The updated model version (step number).
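One plausible reading of `world_size` and `checkpoint_shard_counter` is that each trainer rank reports its shard of a step separately, and the version advances only once all `world_size` shards have arrived. The sketch below assumes exactly that; `ShardCountingSynchronizer` is hypothetical, it requires both arguments (the real method also accepts `None`), and it skips the actual checkpoint loading.

```python
import asyncio
from collections import defaultdict
from typing import Dict, List, Tuple


class ShardCountingSynchronizer:
    """Hypothetical sketch of per-step shard counting; not the real Synchronizer."""

    def __init__(self) -> None:
        self.model_version = 0
        # Maps a training step to the number of shards received for it so far.
        self.checkpoint_shard_counter: Dict[int, int] = defaultdict(int)

    async def set_model_state_dict_with_step_num(self, step_num: int, world_size: int) -> int:
        self.checkpoint_shard_counter[step_num] += 1
        # Only the call that delivers the last expected shard publishes the step.
        if self.checkpoint_shard_counter[step_num] == world_size:
            del self.checkpoint_shard_counter[step_num]
            # A real implementation would load the step's checkpoint here.
            self.model_version = step_num
        return self.model_version


async def main() -> Tuple[List[int], int, Dict[int, int]]:
    sync = ShardCountingSynchronizer()
    # Four trainer ranks each report one shard of step 10.
    versions = await asyncio.gather(
        *(sync.set_model_state_dict_with_step_num(10, world_size=4) for _ in range(4))
    )
    return list(versions), sync.model_version, dict(sync.checkpoint_shard_counter)


versions, final_version, pending = asyncio.run(main())
```

Deleting the counter entry once a step completes keeps `checkpoint_shard_counter` from growing with every training step.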
- async set_model_state_dict(model_state_dict: dict | None, trainer_step: int)
Set the new model state and update the version.
- Parameters:
model_state_dict – The PyTorch model state dictionary.
trainer_step – Step number associated with this model version.
- get_state_dict_meta()
Return metadata about the model state (names, data types, shapes).
- Returns:
One (name, dtype, shape) tuple for each entry in the model state dictionary.
- Return type:
List of tuples
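Metadata of this shape can be derived from a state dictionary in a single pass. This sketch assumes each value exposes `dtype` and `shape`, as `torch.Tensor` does; the function name `state_dict_meta` is hypothetical, and `SimpleNamespace` objects stand in for tensors so the example runs without PyTorch.

```python
from types import SimpleNamespace
from typing import Any, List, Mapping, Tuple


def state_dict_meta(state_dict: Mapping[str, Any]) -> List[Tuple[str, Any, Tuple[int, ...]]]:
    """Return one (name, dtype, shape) tuple per entry, preserving the dict's order."""
    # tuple() turns a torch.Size (or any shape sequence) into a plain tuple.
    return [(name, tensor.dtype, tuple(tensor.shape)) for name, tensor in state_dict.items()]


# Stand-ins exposing .dtype and .shape, so the example runs without PyTorch.
fake_state_dict = {
    "layer.weight": SimpleNamespace(dtype="float32", shape=(4, 8)),
    "layer.bias": SimpleNamespace(dtype="float32", shape=(4,)),
}
meta = state_dict_meta(fake_state_dict)
```

A list of this form is the kind of value that `setup_weight_sync_group` accepts as its `state_dict_meta` argument.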
- async setup_weight_sync_group(master_address: str, master_port: int, state_dict_meta: List | None = None)
Notify the explorer actor to set up the weight synchronization group.
This is used to initialize NCCL-based synchronization for distributed training.
- Parameters:
master_address – IP address of the master node.
master_port – Port used for synchronization.
state_dict_meta – Metadata of the model parameters.
- async wait_new_model_state_dict(current_version: int, no_wait: bool = False) -> int
Wait until a new model state is available.
- Parameters:
current_version – The model version currently held by the calling explorer.
- Returns:
The new model version after it has been updated.
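The wait described here maps naturally onto the `_ready_condition` attribute listed above. A minimal sketch, assuming `set_model_state_dict` notifies waiters on an `asyncio.Condition` and that a version counts as new when it exceeds `current_version`; `ModelVersionSketch` is hypothetical and omits the `no_wait` flag.

```python
import asyncio
from typing import Optional


class ModelVersionSketch:
    """Hypothetical sketch of publishing and awaiting new model versions."""

    def __init__(self) -> None:
        self.model_version = 0
        self.model_state_dict: Optional[dict] = None
        self._ready_condition = asyncio.Condition()

    async def set_model_state_dict(self, model_state_dict: Optional[dict], trainer_step: int) -> None:
        async with self._ready_condition:
            self.model_state_dict = model_state_dict
            self.model_version = trainer_step
            # Wake every explorer blocked in wait_new_model_state_dict.
            self._ready_condition.notify_all()

    async def wait_new_model_state_dict(self, current_version: int) -> int:
        async with self._ready_condition:
            # wait_for re-checks the predicate after every notification.
            await self._ready_condition.wait_for(lambda: self.model_version > current_version)
            return self.model_version


async def main() -> int:
    sync = ModelVersionSketch()
    waiter = asyncio.create_task(sync.wait_new_model_state_dict(current_version=0))
    await asyncio.sleep(0)  # let the waiter block on the condition first
    await sync.set_model_state_dict({"w": 1.0}, trainer_step=5)
    return await waiter


new_version = asyncio.run(main())
```

Using `wait_for` with a predicate, rather than a bare `wait`, means an explorer that calls in after the update has already happened returns immediately instead of blocking until the next one.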
- async ready_to_nccl_sync(module: str, trainer_step: int | None = None) -> int | None
Prepare for NCCL-based synchronization between modules.
Currently supports only a single explorer.
- Parameters:
module – Either ‘trainer’ or ‘explorer’.
trainer_step – Optional step number from the trainer.
- Returns:
The model version if both sides are ready; otherwise None.
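A two-party rendezvous with a timeout is one way to obtain the "version if both sides are ready, otherwise None" behavior. The sketch assumes the trainer's `trainer_step` becomes the reported version and that a caller gives up after a timeout; `NcclRendezvousSketch` and its `timeout` parameter are hypothetical, and the actual NCCL transfer is not shown.

```python
import asyncio
from typing import List, Optional, Set, Tuple


class NcclRendezvousSketch:
    """Hypothetical two-party rendezvous between one trainer and one explorer."""

    def __init__(self, timeout: float = 1.0) -> None:
        self.timeout = timeout  # hypothetical: seconds to wait for the other side
        self.model_version = 0
        self._arrived: Set[str] = set()
        self._generation = 0  # bumped each time both sides meet
        self._ready_condition = asyncio.Condition()

    async def ready_to_nccl_sync(self, module: str, trainer_step: Optional[int] = None) -> Optional[int]:
        if module not in ("trainer", "explorer"):
            raise ValueError(f"module must be 'trainer' or 'explorer', got {module!r}")
        async with self._ready_condition:
            if module == "trainer" and trainer_step is not None:
                self.model_version = trainer_step
            self._arrived.add(module)
            if len(self._arrived) == 2:
                # Second arrival: reset for the next round and wake the partner.
                self._arrived.clear()
                self._generation += 1
                self._ready_condition.notify_all()
                return self.model_version
            generation = self._generation
            try:
                await asyncio.wait_for(
                    self._ready_condition.wait_for(lambda: self._generation != generation),
                    timeout=self.timeout,
                )
            except asyncio.TimeoutError:
                # The lock is held again here; the partner may have arrived just in time.
                if self._generation == generation:
                    self._arrived.discard(module)
                    return None
            return self.model_version


async def main() -> Tuple[List[Optional[int]], Optional[int]]:
    sync = NcclRendezvousSketch(timeout=0.1)
    trainer = asyncio.create_task(sync.ready_to_nccl_sync("trainer", trainer_step=7))
    explorer = asyncio.create_task(sync.ready_to_nccl_sync("explorer"))
    both = await asyncio.gather(trainer, explorer)
    # An explorer arriving with no trainer gives up after the timeout.
    alone = await sync.ready_to_nccl_sync("explorer")
    return list(both), alone


both, alone = asyncio.run(main())
```

A generation counter, rather than a plain flag, lets the same object be reused for the next synchronization round without a stale arrival satisfying it early.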
- classmethod get_actor(config: Config | None = None, namespace: str | None = None)
Get or create a remote Ray actor for the Synchronizer.
- Parameters:
config – Optional configuration to use for creating the actor.
namespace – Optional Ray namespace for the actor.
- Returns:
A reference to the Synchronizer actor.