# -*- coding: utf-8 -*-
"""Custom vLLM Worker."""
import ray
import torch
import torch.distributed

from trinity.utils.distributed import init_process_group, is_ipv6_address
from trinity.utils.log import get_logger

logger = get_logger(__name__)


class WorkerExtension:
    """Extension that adds weight-synchronization support to vLLM workers."""
    def init_process_group(
        self,
        master_address: str,
        master_port: int,
        rank_offset: int,
        world_size: int,
        group_name: str,
        backend: str = "nccl",
        timeout: int = 1200,
        update_with_checkpoint: bool = True,
    ):
        """Init torch process group for model weights update"""
        assert (
            torch.distributed.is_initialized()
        ), "default torch process group must be initialized"
        assert group_name != "", "group name must not be empty"
        self._update_with_checkpoint = update_with_checkpoint
        if self._update_with_checkpoint:
            logger.info(
                f"init_process_group (checkpoint): address={master_address}:{master_port}, "
                f"rank={torch.distributed.get_rank()}, rank_offset={rank_offset}, "
                f"world_size={world_size}"
            )
            self._weight_update_rank = torch.distributed.get_rank() + rank_offset
        else:
            logger.info(
                f"init_process_group (nccl): rank={torch.distributed.get_rank()}, "
                f"rank_offset={rank_offset}, world_size={world_size}"
            )
            self._weight_update_rank = torch.distributed.get_rank() + rank_offset
        if is_ipv6_address(master_address):
            # using tcp://ipv6:port will lead to ValueError
            init_method = f"tcp://[{master_address}]:{master_port}"
        else:
            init_method = f"tcp://{master_address}:{master_port}"
        self._model_update_group = init_process_group(
            backend=backend,
            init_method=init_method,
            timeout=timeout,
            world_size=world_size,
            rank=self._weight_update_rank,
            group_name=group_name,
        )
        logger.info(
            f"init_process_group: master_address={master_address}, master_port={master_port}, "
            f"rank={self._weight_update_rank}, world_size={world_size}, group_name={group_name}"
        )
        self._explorer_actor = None
    def update_weight(self, name, dtype, shape, empty_cache=False):
        """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
        if self._weight_update_rank == 0:
            if self._explorer_actor is None:
                self._explorer_actor = ray.get_actor(name="explorer")
            weight = ray.get(self._explorer_actor.get_weight.remote(name))
            weight = weight.to(self.device)
        else:
            weight = torch.empty(shape, dtype=dtype, device="cuda")
        torch.distributed.broadcast(weight, 0, group=self._model_update_group)
        weight = weight.type(self.model_config.dtype)
        self.model_runner.model.load_weights(weights=[(name, weight)])
        del weight
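

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module). In the NCCL-broadcast
# mode (``update_with_checkpoint=False``), rank 0 of the weight-update group is
# presumably held by the trainer/actor process, while the vLLM workers occupy
# ranks ``rank_offset .. rank_offset + num_workers - 1`` and run
# ``WorkerExtension.update_weight`` once per parameter. The helper below
# sketches that rank-0 side under these assumptions; ``broadcast_model_weights``
# and ``notify_workers`` are hypothetical names, not confirmed Trinity APIs.
# ---------------------------------------------------------------------------
def broadcast_model_weights(
    model: torch.nn.Module,
    notify_workers,
    master_address: str,
    master_port: int,
    world_size: int,
    group_name: str,
) -> None:
    """Hypothetical rank-0 counterpart of ``WorkerExtension.update_weight``.

    ``notify_workers(name, dtype, shape)`` is assumed to asynchronously ask
    every vLLM worker to call ``update_weight`` with the same arguments, so
    that all ranks enter the matching broadcast. ``world_size`` must equal
    1 (this process) plus the total number of vLLM workers.
    """
    # Join the group created by ``WorkerExtension.init_process_group`` as the
    # source rank. This assumes the Trinity helper can be used the same way
    # from the trainer side; IPv6 addresses would need the bracketed form, as
    # handled in ``init_process_group`` above.
    group = init_process_group(
        backend="nccl",
        init_method=f"tcp://{master_address}:{master_port}",
        timeout=1200,
        world_size=world_size,
        rank=0,
        group_name=group_name,
    )
    for name, param in model.named_parameters():
        weight = param.data.to("cuda")
        # Tell the workers to allocate their receive buffers and join the
        # collective, then broadcast the actual tensor from rank 0.
        notify_workers(name, weight.dtype, weight.shape)
        torch.distributed.broadcast(weight, 0, group=group)
        del weight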