# -*- coding: utf-8 -*-
"""For distributed training with multiple process groups."""
import ipaddress
from datetime import timedelta
from typing import Any, Optional, Union

import torch
from torch.distributed.distributed_c10d import (
    Backend,
    PrefixStore,
    Store,
    _new_process_group_helper,
    _world,
    default_pg_timeout,
    rendezvous,
)


def is_ipv6_address(ip_str: str) -> bool:
    """Return True if ``ip_str`` parses as a valid IPv6 address, else False."""
    try:
        ip = ipaddress.ip_address(ip_str)
        return isinstance(ip, ipaddress.IPv6Address)
    except ValueError:
        return False
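
# Illustrative usage (the addresses below are placeholders, not values used
# elsewhere in this module); IPv6 literals generally need to be bracketed when
# embedded in a tcp:// init_method URL:
#
#     is_ipv6_address("::1")        # True
#     is_ipv6_address("10.0.0.1")   # False
#     host = "fe80::1"
#     url = f"tcp://[{host}]:29500" if is_ipv6_address(host) else f"tcp://{host}:29500"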


def init_process_group(
    backend: Optional[Union[str, Backend]] = None,
    init_method: Optional[str] = None,
    timeout: Optional[float] = None,
    world_size: int = -1,
    rank: int = -1,
    store: Optional[Store] = None,
    group_name: Optional[str] = None,
    pg_options: Optional[Any] = None,
):
    """Create and return a new process group identified by ``group_name``.

    Accepts the same rendezvous options as ``torch.distributed.init_process_group``
    (``init_method`` or an explicit ``store``, plus ``world_size`` and ``rank``),
    with ``timeout`` given in seconds, and returns the created group instead of
    installing it as the default one.
    """
    assert (store is None) or (init_method is None), "Cannot specify both init_method and store."
    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")

    # Normalize the timeout: a plain float of seconds becomes a timedelta,
    # and None falls back to the c10d default.
    if timeout is None:
        timeout = default_pg_timeout
    else:
        timeout = timedelta(seconds=timeout)

    # backward compatible API
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

        # Use a PrefixStore to avoid accidental overrides of keys used by
        # different systems (e.g. RPC) in case the store is multi-tenant.
        store = PrefixStore(group_name, store)

    # NOTE: The pg_options parameter was renamed to backend_options in PyTorch 2.6.0
    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
    # Pick the keyword expected by the installed PyTorch version; compare version
    # components numerically rather than as strings (a string compare would sort
    # "2.10" before "2.6").
    torch_major_minor = tuple(int(v) for v in torch.__version__.split(".")[:2])
    pg_options_param_name = "backend_options" if torch_major_minor >= (2, 6) else "pg_options"

    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    # Register the rank mapping for the new group (identity: global rank i maps
    # to group rank i for every rank in the group).
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg
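

# Minimal usage sketch, not part of the module's API: assumes the script is
# launched with torchrun so that RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT
# are set, and that the gloo backend is available. The group name "example_pg"
# is illustrative only.
if __name__ == "__main__":
    pg = init_process_group(
        backend="gloo",
        init_method="env://",
        timeout=1800.0,
        group_name="example_pg",
    )
    print(f"created process group: rank {pg.rank()} of {pg.size()}")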