# -*- coding: utf-8 -*-
"""For distributed training with multiple process groups."""
import ipaddress
from datetime import timedelta
from typing import Any, Optional, Union

import torch
from torch.distributed.distributed_c10d import (
    Backend,
    PrefixStore,
    _new_process_group_helper,
    _world,
    default_pg_timeout,
    rendezvous,
)


def is_ipv6_address(ip_str: str) -> bool:
    """Return True if ``ip_str`` is a literal IPv6 address, False otherwise."""
    try:
        ip = ipaddress.ip_address(ip_str)
        return isinstance(ip, ipaddress.IPv6Address)
    except ValueError:
        # Not an IP literal at all (e.g. a hostname).
        return False
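

# Assumed behavior of the helper above (standard ``ipaddress`` semantics):
#     is_ipv6_address("::1")       -> True
#     is_ipv6_address("10.0.0.1")  -> False
#     is_ipv6_address("node-01")   -> False  (ValueError is caught)
# init_process_group() below uses this to decide whether the host must be
# wrapped in brackets for the ``tcp://[host]:port`` init method.

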
def init_process_group(
    host: str,
    port: int,
    group_name: str,
    backend: Union[str, Backend] = "nccl",
    timeout: Optional[float] = None,
    world_size: int = -1,
    rank: int = -1,
    pg_options: Optional[Any] = None,
):
    """Create an additional process group via a TCP rendezvous at ``host:port``.

    Unlike ``torch.distributed.init_process_group``, this helper does not set up
    the default process group; it only builds and registers a new group named
    ``group_name``. ``timeout`` is given in seconds and defaults to
    ``default_pg_timeout`` when omitted.
    """
    assert backend == "nccl", "Only nccl backend is supported for now."
    from torch.distributed.distributed_c10d import is_nccl_available
    assert is_nccl_available(), "NCCL is not available in this build of torch."
    # IPv6 literals must be wrapped in brackets inside the TCP init-method URL.
    init_method = (
        f"tcp://[{host}]:{port}" if is_ipv6_address(ip_str=host) else f"tcp://{host}:{port}"
    )
    backend = Backend(backend)

    # Normalize the timeout (given in seconds) to the timedelta expected by the
    # c10d APIs, falling back to the library default.
    if timeout is None:
        timeout = default_pg_timeout
    else:
        timeout = timedelta(seconds=timeout)
    # ``rendezvous`` yields the (store, rank, world_size) triple resolved from
    # the TCP init method; taking the first result matches the older
    # (backward-compatible) c10d API.
    store, rank, world_size = next(rendezvous(init_method, rank, world_size, timeout=timeout))
    store.set_timeout(timeout)
    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(group_name, store)
    # torch >= 2.6 renamed the ``pg_options`` argument of
    # ``_new_process_group_helper`` to ``backend_options``; compare the version
    # numerically so that e.g. "2.10" is not ordered before "2.6".
    torch_version = tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2])
    pg_options_param_name = "backend_options" if torch_version >= (2, 6) else "pg_options"
    pg, _ = _new_process_group_helper(
        group_size=world_size,
        group_rank=rank,
        global_ranks_in_group=[],
        backend=backend,
        store=prefix_store,
        group_name=group_name,
        timeout=timeout,
        **{pg_options_param_name: pg_options},
    )
    # Record the global-rank -> group-rank mapping; every rank joins this group,
    # so the mapping is the identity.
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
    return pg
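

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module proper. It assumes each rank
    # is started by a launcher such as torchrun, which exports RANK, WORLD_SIZE,
    # MASTER_ADDR and MASTER_PORT; the group name here is an arbitrary placeholder.
    import os

    pg = init_process_group(
        host=os.environ.get("MASTER_ADDR", "127.0.0.1"),
        port=int(os.environ.get("MASTER_PORT", "29500")),
        group_name="demo_pg",  # also namespaces the PrefixStore keys
        timeout=600,           # seconds; normalized to a timedelta above
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
        rank=int(os.environ.get("RANK", "0")),
    )
    # The returned ProcessGroup can be passed to collectives, e.g.
    # torch.distributed.all_reduce(tensor, group=pg) on CUDA tensors, since only
    # the NCCL backend is supported by this helper.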