# Source code for trinity.utils.distributed

# -*- coding: utf-8 -*-
"""For distributed training with multiple process groups."""
import ipaddress
from datetime import timedelta
from typing import Any, Optional, Union

import torch
from torch.distributed.distributed_c10d import (
    Backend,
    PrefixStore,
    Store,
    _new_process_group_helper,
    _world,
    default_pg_timeout,
    rendezvous,
)


def is_ipv6_address(ip_str: str) -> bool:
    """Return True if ``ip_str`` parses as an IPv6 address.

    Args:
        ip_str: Candidate address text.

    Returns:
        True when ``ipaddress.ip_address`` accepts the value and yields an
        IPv6 address; False for IPv4 addresses and for anything that fails
        to parse.
    """
    try:
        parsed = ipaddress.ip_address(ip_str)
    except ValueError:
        # Not a valid IP address of either family.
        return False
    else:
        # ip_address() yields either an IPv4Address (version 4) or an
        # IPv6Address (version 6).
        return parsed.version == 6
def _torch_version_at_least(major: int, minor: int) -> bool:
    """Return whether the installed PyTorch is at least ``major.minor``.

    The leading ``<major>.<minor>`` of ``torch.__version__`` is compared
    numerically, so ``2.10`` ranks above ``2.6``. A plain string comparison
    would rank it below, because ``"1" < "6"``.
    """
    import re  # local import keeps this block self-contained

    version_text = str(torch.__version__)
    parsed = re.match(r"(\d+)\.(\d+)", version_text)
    if parsed is None:
        # Unrecognised version format: keep the historical string comparison.
        return version_text >= f"{major}.{minor}"
    return (int(parsed.group(1)), int(parsed.group(2))) >= (major, minor)


def init_process_group(
    backend: Optional[Union[str, Backend]] = None,
    init_method: Optional[str] = None,
    timeout: Optional[Union[float, timedelta]] = None,
    world_size: int = -1,
    rank: int = -1,
    store: Optional[Store] = None,
    group_name: Optional[str] = None,
    pg_options: Optional[Any] = None,
):
    """Create and return a new process group via ``_new_process_group_helper``.

    Args:
        backend: Backend name or ``Backend`` instance. A falsy value selects
            ``Backend("undefined")``.
        init_method: Rendezvous URL used when no ``store`` is given; defaults
            to ``"env://"``. Mutually exclusive with ``store``.
        timeout: Timeout in seconds, or a ready-made ``timedelta``. ``None``
            selects ``default_pg_timeout``.
        world_size: Number of participating processes. Must be positive when
            ``store`` is given; otherwise it is handed to ``rendezvous``, whose
            returned value replaces it.
        rank: Rank of this process. Must be non-negative when ``store`` is
            given; otherwise it is handed to ``rendezvous``, whose returned
            value replaces it.
        store: Existing key/value store to use instead of a rendezvous.
        group_name: Name passed to ``_new_process_group_helper``; also used as
            the ``PrefixStore`` prefix when the store comes from rendezvous.
            NOTE(review): the default ``None`` is passed to ``PrefixStore``
            unchanged -- confirm callers always supply a name on that path.
        pg_options: Backend-specific options, forwarded under whichever
            keyword name the installed PyTorch expects.

    Returns:
        The process group created by ``_new_process_group_helper``.

    Raises:
        AssertionError: If both ``store`` and ``init_method`` are given, or if
            ``store`` is given without a positive ``world_size`` and a
            non-negative ``rank``.
    """
    assert (store is None) or (init_method is None), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    backend = Backend(backend) if backend else Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout
    elif not isinstance(timeout, timedelta):
        # Numeric seconds (the original contract); timedelta passes through.
        timeout = timedelta(seconds=timeout)

    # backward compatible API
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

        # Use a PrefixStore to avoid accidental overrides of keys used by
        # different systems (e.g. RPC) in case the store is multi-tenant.
        store = PrefixStore(group_name, store)

    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
    # The version is compared numerically: comparing the raw strings would
    # treat "2.10" as older than "2.6" and select the removed keyword.
    pg_options_param_name = "backend_options" if _torch_version_at_least(2, 6) else "pg_options"
    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    # Record an identity rank mapping (i -> i) for every rank of the new group.
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg