Source code for trinity.common.models.vllm_model

# -*- coding: utf-8 -*-
"""vLLM related modules.

Modified from OpenRLHF openrlhf/trainer/ray/vllm_engine.py
"""
from __future__ import annotations

import os
import re
import threading
from typing import List

import torch
import vllm
from vllm import LLM
from vllm.sampling_params import SamplingParams

from trinity.common.config import InferenceModelConfig
from trinity.common.experience import Experience
from trinity.common.models.model import InferenceModel
from trinity.common.models.utils import (
    tokenize_and_mask_messages_default,
    tokenize_and_mask_messages_hf,
)
from trinity.utils.log import get_logger


class vLLMRolloutModel(InferenceModel):
    """Actor for vLLM."""

    def __init__(self, config: InferenceModelConfig):
        self.logger = get_logger(__name__)
        self.config = config
        if config.tensor_parallel_size != 1:
            os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
            os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.bundle_indices
        if not vllm.envs.is_set("VLLM_USE_V1"):
            self.logger.info(f"Using vLLM v{int(config.use_v1)} engine")
            os.environ["VLLM_USE_V1"] = str(int(config.use_v1))
            if config.use_v1:
                os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.use_v1))
                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
                os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
        self.default_sampling_params = SamplingParams(
            n=1,
            temperature=0.0,
            max_tokens=config.max_response_tokens,
            min_tokens=1,
            truncate_prompt_tokens=config.max_prompt_tokens,
            skip_special_tokens=True,
            include_stop_str_in_output=False,
            logprobs=0,
        )
        max_model_len = None
        if config.max_prompt_tokens is not None and config.max_response_tokens is not None:
            max_model_len = config.max_prompt_tokens + config.max_response_tokens
        self.llm = LLM(
            # TODO: check checkpoint path
            model=config.model_path,
            enforce_eager=config.enforce_eager,
            worker_extension_cls="trinity.common.models.vllm_worker.WorkerExtension",
            tensor_parallel_size=config.tensor_parallel_size,
            seed=config.seed,
            distributed_executor_backend=("uni" if config.tensor_parallel_size == 1 else "ray"),
            max_model_len=max_model_len,
            enable_prefix_caching=config.enable_prefix_caching,
            dtype=config.dtype,
            trust_remote_code=True,
            gpu_memory_utilization=config.gpu_memory_utilization,
            enable_chunked_prefill=config.enable_chunked_prefill,
            # max_num_batched_tokens=256,
        )
        self.tokenizer = self.llm.get_tokenizer()
        self.chat_template = self.tokenizer.get_chat_template()
        self.enable_thinking = config.enable_thinking
        if self.config.chat_template:
            self.chat_template = self.config.chat_template
        if not re.search(r"\{\%-?\s*generation\s*-?\%\}", self.chat_template):
            self.logger.warning(
                "The provided chat template does not support `return_assistant_tokens_mask`. "
                "The default assistant mask method will be used, which may cause performance "
                "degradation and lead to incorrect results."
            )
            self.action_mask_method = tokenize_and_mask_messages_default
        else:
            self.action_mask_method = tokenize_and_mask_messages_hf
        self.lock = threading.Lock()
        self.ckp_version = 0  # TODO: resume the value from the checkpoint

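    # Illustrative construction sketch (not the project's actual entry point). It
    # assumes ``InferenceModelConfig`` accepts the fields read above as keyword
    # arguments and uses placeholder values; in Trinity the model is typically
    # created from a populated config (often behind a Ray actor) rather than by hand:
    #
    #     config = InferenceModelConfig(
    #         model_path="/path/to/model",
    #         tensor_parallel_size=1,
    #         max_prompt_tokens=2048,
    #         max_response_tokens=1024,
    #     )
    #     model = vLLMRolloutModel(config)
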
    def init_process_group(
        self,
        master_address: str,
        master_port: int,
        rank_offset: int,
        world_size: int,
        group_name: str,
        backend: str = "nccl",
        timeout: int = 1200,
        update_with_checkpoint: bool = True,
    ):
        return self.llm.collective_rpc(
            "init_process_group",
            args=(
                master_address,
                master_port,
                rank_offset,
                world_size,
                group_name,
                backend,
                timeout,
                update_with_checkpoint,
            ),
        )

    def update_weight(self, name, dtype, shape, empty_cache=False):
        return self.llm.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))

    def reset_prefix_cache(self):
        self.llm.llm_engine.reset_prefix_cache()

    def sleep(self, level=1):
        self.llm.sleep(level=level)

    def wake_up(self):
        self.llm.wake_up()

    def _create_sampling_params(self, **kwargs):
        """Create sampling params."""
        if len(kwargs) == 0:
            return self.default_sampling_params
        params = self.default_sampling_params.clone()
        for k, v in kwargs.items():
            if hasattr(params, k):
                setattr(params, k, v)
        return params

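    # Illustrative sketch of the override semantics above: keyword arguments are set
    # on a clone of ``default_sampling_params``, so the stored defaults stay untouched.
    #
    #     params = model._create_sampling_params(n=4, temperature=1.0)
    #     assert params.n == 4 and params.temperature == 1.0
    #     assert model.default_sampling_params.temperature == 0.0
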
    def generate(self, prompts: List[str], **kwargs) -> List:
        """Generate a batch of responses from a batch of prompts.

        Note:
            This method does not apply the chat template. You need to apply the chat
            template to each prompt before calling this method.

        Args:
            prompts (List[str]): A list of prompts.
            kwargs (dict): A dictionary of sampling parameters.

        Returns:
            List[Experience]: A list of experiences.

        Example:

            >>> # config.algorithm.repeat_times == 2 or kwargs["n"] == 2
            >>>
            >>> prompts = [
            >>>     "Hello, world!",
            >>>     "How are you?"
            >>> ]
            >>> experiences = model.generate(prompts)
            >>> print(experiences)
            [
                Experience(tokens=tensor()...),  # first sequence for prompts[0]
                Experience(tokens=tensor()...),  # second sequence for prompts[0]
                Experience(tokens=tensor()...),  # first sequence for prompts[1]
                Experience(tokens=tensor()...)   # second sequence for prompts[1]
            ]
        """
        with self.lock:
            sampling_params = self._create_sampling_params(**kwargs)
            outputs = self.llm.generate(prompts, sampling_params, use_tqdm=False)
            experiences = []
            for output in outputs:
                for i in range(sampling_params.n):
                    experiences.append(
                        Experience(
                            tokens=torch.cat(
                                (
                                    torch.tensor(output.prompt_token_ids, dtype=torch.int32),
                                    torch.tensor(output.outputs[i].token_ids, dtype=torch.int32),
                                )
                            ),
                            logprobs=torch.cat(
                                (
                                    torch.full(
                                        (len(output.prompt_token_ids),),
                                        0.0,
                                        dtype=torch.float32,
                                    ),
                                    torch.tensor(
                                        [
                                            list(logprob_dict.values())[0].logprob
                                            for logprob_dict in output.outputs[i].logprobs
                                        ],
                                        dtype=torch.float32,
                                    ),
                                )
                            ),
                            prompt_length=len(output.prompt_token_ids),
                            prompt_text=output.prompt,
                            response_text=output.outputs[i].text,
                        )
                    )
            return experiences

    def chat(self, messages: List[dict], **kwargs) -> List[Experience]:
        """Chat with the model using a list of messages.

        Args:
            messages (List[dict]): A list of messages. Example:

                >>> [
                >>>     {"role": "system", "content": "You are a helpful assistant."},
                >>>     {"role": "user", "content": "Hello, world!"},
                >>> ]

        Returns:
            List[Experience]: A list of experiences containing the response text.
        """
        # TODO: support tools and other fields
        if messages[-1]["role"] == "assistant":
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                continue_final_message=True,
                chat_template=self.chat_template,
            )
        else:
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                chat_template=self.chat_template,
                enable_thinking=self.enable_thinking,
            )
        return self.generate([prompt], **kwargs)

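    # Usage sketch for ``chat`` (assumes an already constructed ``model``; with the
    # default sampling parameters a single response is returned):
    #
    #     messages = [
    #         {"role": "system", "content": "You are a helpful assistant."},
    #         {"role": "user", "content": "Hello, world!"},
    #     ]
    #     experiences = model.chat(messages, temperature=1.0)
    #     print(experiences[0].response_text)
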
    def logprobs(self, token_ids: List[int]) -> torch.Tensor:
        """Compute prompt logprobs for the given token ids.

        The first position has no prompt logprob in vLLM's output and is filled with 0.
        """
        with self.lock:
            outputs = self.llm.generate(
                prompts={"prompt_token_ids": token_ids},
                sampling_params=self._create_sampling_params(
                    n=1,
                    max_tokens=1,
                    prompt_logprobs=0,
                ),
                use_tqdm=False,
            )
            return torch.tensor(
                [0]
                + [
                    list(logprob_dict.values())[0].logprob
                    for logprob_dict in outputs[0].prompt_logprobs[1:]
                ],
                dtype=torch.float32,
            )

    def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
        """Convert a list of messages into an experience."""
        token_ids, action_mask = self.action_mask_method(
            self.tokenizer, messages, self.chat_template
        )
        logprobs = self.logprobs(token_ids=token_ids.tolist())
        return Experience(
            tokens=token_ids,
            prompt_length=len(token_ids),
            logprobs=logprobs,
            action_mask=action_mask,
        )

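    # Illustrative use of ``convert_messages_to_experience`` (a sketch; assumes the
    # mask returned by ``action_mask_method`` is a token-level tensor aligned with
    # ``tokens``, marking the assistant-generated positions):
    #
    #     exp = model.convert_messages_to_experience([
    #         {"role": "user", "content": "Hello, world!"},
    #         {"role": "assistant", "content": "Hi there!"},
    #     ])
    #     print(exp.tokens.shape, exp.action_mask.sum())
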
    def has_api_server(self) -> bool:
        return False

    def sync_model(self, update_weight_args_list) -> bool:
        """Sync model weights to vLLM."""
        with self.lock:
            for args in update_weight_args_list:
                self.llm.collective_rpc("update_weight", args=args)
            self.logger.info("Synced model weights to vLLM successfully.")
            self.ckp_version += 1
        return True

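    # Illustrative call pattern for weight syncing (a sketch, not the trainer's
    # actual code): each entry of ``update_weight_args_list`` is forwarded verbatim
    # to the worker-side ``update_weight`` RPC, so it should follow that call's
    # argument order, e.g. ``(name, dtype, shape, empty_cache)``; ``named_parameters``
    # below stands in for however the trainer enumerates its weights:
    #
    #     args_list = [
    #         (name, param.dtype, param.shape, False)
    #         for name, param in named_parameters
    #     ]
    #     model.sync_model(args_list)
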
    def get_ckp_version(self) -> int:
        return self.ckp_version