trinity.common.models.vllm_model module

A wrapper around the vLLM async engine to handle async requests.

class trinity.common.models.vllm_model.vLLMRolloutModel(config: InferenceModelConfig)[source]

Bases: InferenceModel

Wrapper around the vLLM engine to handle async requests.

Parameters:

config (InferenceModelConfig) – The inference model config.

__init__(config: InferenceModelConfig) → None[source]
async chat(messages: List[Dict], **kwargs) → Sequence[Experience][source]

Chat with the model asynchronously using a list of history messages.

Parameters:
  • messages (List[dict]) – The input history messages.

  • kwargs (dict) – A dictionary of sampling parameters.

Returns:

A list of experiences.

async generate(prompt: str, **kwargs) → Sequence[Experience][source]

Asynchronously generate a response from the provided prompt.

Parameters:
  • prompt (str) – The input prompt.

  • kwargs (dict) – A dictionary of sampling parameters.

Returns:

A list of experiences.
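
A minimal usage sketch for chat() and generate(), assuming an already-constructed vLLMRolloutModel instance; the sampling parameters shown (n, temperature) are illustrative kwargs, not a documented list:

    import asyncio

    from trinity.common.models.vllm_model import vLLMRolloutModel

    async def rollout(model: vLLMRolloutModel) -> None:
        # Multi-turn chat; kwargs are forwarded as sampling parameters.
        chat_experiences = await model.chat(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is 2 + 2?"},
            ],
            n=2,              # illustrative: samples per prompt
            temperature=1.0,  # illustrative sampling parameter
        )
        # Raw-prompt completion, without a chat template.
        gen_experiences = await model.generate(
            "Q: What is 2 + 2?\nA:", temperature=0.7
        )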

async logprobs(token_ids: List[int]) → Tensor[source]

Asynchronously calculate the logprobs of the given tokens. Slice the result carefully to align with the actual response length: each logprob is conditioned on the preceding tokens, so the output has one entry fewer than the input.

Parameters:

token_ids (List[int]) – The input token ids (seq_length).

Returns:

A tensor of logprobs (seq_length - 1).
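
A sketch of the slicing described above, assuming prompt_ids and response_ids hold the token ids of the prompt and of the model's response:

    # logprobs() returns seq_length - 1 values: the first token has no
    # preceding context, so it has no logprob.
    token_ids = prompt_ids + response_ids
    lp = await model.logprobs(token_ids=token_ids)  # shape: (len(token_ids) - 1,)
    # Drop the prompt positions so one logprob remains per response token.
    response_logprobs = lp[len(prompt_ids) - 1:]
    assert response_logprobs.shape[0] == len(response_ids)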

async convert_messages_to_experience(messages: List[dict]) → Experience[source]

Convert a list of messages into an experience.

shutdown()[source]

Shut down the vLLM v1 engine. This kills the child processes forked by the vLLM engine. If it is not called, those child processes are orphaned: they are not killed when the parent process exits, and Ray can no longer track them.

async sync_model(model_version: int) → bool[source]

Sync the model weights of the specified version into vLLM.

async init_process_group(master_address: str, master_port: int, rank_offset: int, world_size: int, group_name: str, explorer_name: str, backend: str = 'nccl', timeout: int = 1200, state_dict_meta: dict | None = None)[source]
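
A hypothetical sketch of the weight-sync handshake: the rollout engine first joins the trainer's process group via init_process_group(), then pulls fresh weights with sync_model(). All addresses, ranks, and names below are placeholders; in practice they are coordinated by Trinity's explorer and trainer:

    await model.init_process_group(
        master_address="10.0.0.1",  # placeholder: trainer rendezvous address
        master_port=29500,          # placeholder
        rank_offset=1,              # placeholder: this engine's rank offset
        world_size=3,               # placeholder: trainer + rollout engines
        group_name="trinity_sync",  # placeholder
        explorer_name="explorer",   # placeholder
        backend="nccl",
    )
    # After the trainer publishes a new checkpoint version:
    if await model.sync_model(model_version=1):
        assert model.get_model_version() == 1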
async run_api_server()[source]

Run the OpenAI API server in a Ray actor.

Note

Do not call ray.get() on this method; it runs forever until the server is shut down.

async has_api_server() → bool[source]
async api_server_ready() → str | None[source]

Check if the OpenAI API server is ready.

Returns:

The URL of the OpenAI API server, or None if the server is not ready.

Return type:

str | None
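
A sketch of the intended call pattern, assuming model_actor is a Ray actor handle wrapping vLLMRolloutModel:

    import ray

    # Fire-and-forget: run_api_server() blocks forever, so never ray.get() it.
    model_actor.run_api_server.remote()

    # Poll readiness instead; the URL is returned once the server is up.
    api_url = ray.get(model_actor.api_server_ready.remote())
    if api_url is not None:
        print(f"OpenAI-compatible endpoint: {api_url}")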

async reset_prefix_cache() → None[source]
get_model_version() → int[source]

Get the checkpoint version.

async sleep(level: int = 1) → None[source]
async wake_up() → None[source]
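
One plausible way these methods fit together, sketched under the assumption that the engine is put to sleep to free GPU memory while the trainer updates, and that cached prefixes become stale once the weights change:

    # Free GPU memory while the trainer runs its update step.
    await model.sleep(level=1)
    # ... trainer update happens here ...
    await model.wake_up()
    # Weights changed, so drop prefix-cache entries computed with old weights.
    await model.reset_prefix_cache()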