trinity.common.models

Submodules

trinity.common.models.model module

Base Model Class

class trinity.common.models.model.InferenceModel[source]

Bases: ABC

A high-performance model for rollout inference.

generate(prompts: List[str], **kwargs) List[Experience][source]

Generate a batch of responses from a batch of prompts.

chat(messages: List[dict], **kwargs) List[Experience][source]

Generate experiences from a list of history chat messages.

logprobs(token_ids: List[int]) Tensor[source]

Generate logprobs for a list of tokens.

convert_messages_to_experience(messages: List[dict]) Experience[source]

Convert a list of messages into an experience.

async generate_async(prompt: str, **kwargs) List[Experience][source]

Generate responses from a prompt asynchronously.

async chat_async(messages: List[dict], **kwargs) List[Experience][source]

Generate experiences from a list of history chat messages in async.

async logprobs_async(tokens: List[int]) Tensor[source]

Generate logprobs for a list of tokens in async.

async convert_messages_to_experience_async(messages: List[dict]) Experience[source]

Convert a list of messages into an experience in async.

abstract get_model_version() int[source]

Get the checkpoint version.

get_available_address() Tuple[str, int][source]

Get the address of the actor.

class trinity.common.models.model.ModelWrapper(model: Any, model_type: str = 'vllm')[source]

Bases: object

A wrapper for the InferenceModel Ray Actor

__init__(model: Any, model_type: str = 'vllm')[source]
generate(prompts: List[str], **kwargs) List[Experience][source]
chat(messages: List[dict], **kwargs) List[Experience][source]
logprobs(tokens: List[int]) Tensor[source]
convert_messages_to_experience(messages: List[dict]) Experience[source]

Convert a list of messages into an experience.

property model_version: int

Get the version of the model.

get_openai_client() OpenAI[source]

Get the openai client.

Returns:

The OpenAI client. A model_path attribute, which refers to the model path, is added to the client.

Return type:

openai.OpenAI

trinity.common.models.openai_api module

OpenAI API server related tools.

Modified from vllm/entrypoints/openai/api_server.py

async trinity.common.models.openai_api.run_server_in_ray(args, engine_client)[source]
trinity.common.models.openai_api.dummy_add_signal_handler(self, *args, **kwargs)[source]
async trinity.common.models.openai_api.patch_and_serve_http(app, sock, args)[source]

Patch the add_signal_handler method and serve the app.

async trinity.common.models.openai_api.run_api_server_in_ray_actor(async_llm, host: str, port: int, model_path: str)[source]

trinity.common.models.utils module

trinity.common.models.utils.tokenize_and_mask_messages_hf(tokenizer: Any, messages: List[dict], chat_template: str | None = None) Tuple[Tensor, Tensor][source]

Calculate the assistant token mask with chat_template.

Parameters:
  • tokenizer (Any) – The tokenizer.

  • chat_template (str) – The chat template with {% generation %} symbol.

  • messages (List[dict]) – Messages with role and content fields.

Returns:

The token_ids (sequence_length) and assistant_masks (sequence_length).

Return type:

Tuple[torch.Tensor, torch.Tensor]

trinity.common.models.utils.tokenize_and_mask_messages_default(tokenizer: Any, messages: List[dict], chat_template: str | None = None) Tuple[Tensor, Tensor][source]

Calculate the assistant token mask.

Parameters:
  • tokenizer (Any) – The tokenizer.

  • chat_template (str) – The chat template with {% generation %} symbol.

  • messages (List[dict]) – Messages with role and content fields.

Returns:

The token_ids (sequence_length) and assistant_masks (sequence_length).

Return type:

Tuple[torch.Tensor, torch.Tensor]

Note

This method is based on the assumption that as the number of chat rounds increases, the tokens of the previous round are exactly the prefix tokens of the next round. If the assumption is not met, the function may produce incorrect results. Please check the chat template before using this method.
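The prefix assumption in the note above can be illustrated with a minimal sketch. It is not the implementation of tokenize_and_mask_messages_default: ToyTokenizer, its render method, and mask_assistant_tokens are hypothetical names invented for this example, and plain Python lists stand in for tensors. The sketch renders the conversation one round at a time, checks that the earlier rounds are an exact prefix of the longer rendering, and marks the newly added tokens as assistant tokens when the last message comes from the assistant.

```python
from typing import Dict, List, Tuple


class ToyTokenizer:
    """Hypothetical stand-in for a chat tokenizer; not part of trinity."""

    def render(self, messages: List[Dict[str, str]]) -> List[str]:
        # Each message becomes "<role>", its whitespace-split words, then "<eos>".
        tokens: List[str] = []
        for message in messages:
            tokens.append(f"<{message['role']}>")
            tokens.extend(message["content"].split())
            tokens.append("<eos>")
        return tokens


def mask_assistant_tokens(
    tokenizer: ToyTokenizer, messages: List[Dict[str, str]]
) -> Tuple[List[str], List[int]]:
    """Return (tokens, mask), where mask is 1 on tokens of assistant turns."""
    tokens: List[str] = []
    mask: List[int] = []
    for end in range(1, len(messages) + 1):
        rendered = tokenizer.render(messages[:end])
        # The prefix assumption: rendering more rounds must not change
        # the tokens of the rounds rendered so far.
        if rendered[: len(tokens)] != tokens:
            raise ValueError("chat template violates the prefix assumption")
        new_tokens = rendered[len(tokens):]
        is_assistant = int(messages[end - 1]["role"] == "assistant")
        tokens.extend(new_tokens)
        mask.extend([is_assistant] * len(new_tokens))
    return tokens, mask


messages = [
    {"role": "user", "content": "Hi there"},
    {"role": "assistant", "content": "Hello"},
]
tokens, mask = mask_assistant_tokens(ToyTokenizer(), messages)
```

In this toy version the whole assistant turn, including its role marker, is masked as 1. A chat template that rewrites earlier rounds (for example, by dropping content from previous turns) would trigger the ValueError, which is the situation the note warns about.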

trinity.common.models.utils.get_checkpoint_dir_with_step_num(checkpoint_root_path: str, trainer_type: str = 'verl', step_num: int | None = None) Tuple[str, int][source]

Get the checkpoint directory from a root checkpoint directory.

Parameters:
  • checkpoint_root_path (str) – The root checkpoint directory.

  • trainer_type (str) – The trainer type. Only “verl” is supported for now.

  • step_num (Optional[int], optional) – The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None.

Returns:

The checkpoint directory and the step number of the checkpoint.

Return type:

Tuple[str, int]
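The step_num behaviour described above (use the given step, or fall back to the latest one) can be sketched as follows. This is an illustration only, not the library's implementation: find_checkpoint_dir is a hypothetical helper, and the global_step_<N> directory naming is an assumption made for the example.

```python
import os
import re
import tempfile
from typing import Optional, Tuple

# Assumed naming scheme for step directories; the real layout may differ.
STEP_DIR_PATTERN = re.compile(r"^global_step_(\d+)$")


def find_checkpoint_dir(root: str, step_num: Optional[int] = None) -> Tuple[str, int]:
    """Return (checkpoint_dir, step) for step_num, or for the latest step if None."""
    steps = []
    for name in os.listdir(root):
        match = STEP_DIR_PATTERN.match(name)
        if match and os.path.isdir(os.path.join(root, name)):
            steps.append(int(match.group(1)))
    if not steps:
        raise FileNotFoundError(f"no checkpoint found under {root}")
    if step_num is None:
        step_num = max(steps)  # latest checkpoint
    elif step_num not in steps:
        raise FileNotFoundError(f"step {step_num} not found under {root}")
    return os.path.join(root, f"global_step_{step_num}"), step_num


# Build a throwaway checkpoint root with three step directories.
with tempfile.TemporaryDirectory() as root:
    for step in (5, 10, 20):
        os.makedirs(os.path.join(root, f"global_step_{step}"))
    latest_dir, latest_step = find_checkpoint_dir(root)
    chosen_dir, chosen_step = find_checkpoint_dir(root, step_num=10)
```

Comparing the step numbers as integers rather than sorting directory names avoids picking global_step_5 over global_step_20 by string order.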

trinity.common.models.utils.load_state_dict(checkpoint_dir: str, trainer_type: str = 'verl') dict[source]

Load state dict from a checkpoint dir.

Parameters:
  • checkpoint_dir (str) – The checkpoint directory.

  • trainer_type (str) – The trainer type. Only “verl” is supported for now.

trinity.common.models.utils.merge_by_placement(tensors: List[Tensor], placement: Placement)[source]
trinity.common.models.utils.get_verl_checkpoint_info(checkpoint_path: str, step_num: int | None = None) Tuple[str, int][source]

Get the checkpoint directory from a Verl root checkpoint directory.

Parameters:
  • checkpoint_path (str) – The root checkpoint directory.

  • step_num (Optional[int], optional) – The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None.

Returns:

The checkpoint directory and the step number of the checkpoint.

Return type:

Tuple[str, int]

trinity.common.models.utils.load_state_dict_from_verl_checkpoint(checkpoint_path: str) dict[source]

Load state dict from a Verl checkpoint.

trinity.common.models.vllm_async_model module

vLLM AsyncEngine wrapper.

Modified from Ray python/ray/llm/_internal/batch/stages/vllm_engine_stage.py

class trinity.common.models.vllm_async_model.vLLMAysncRolloutModel(config: InferenceModelConfig)[source]

Bases: InferenceModel

Wrapper around the vLLM engine to handle async requests.

Parameters:
  • config (InferenceModelConfig) – The config.

  • kwargs (dict) – The keyword arguments for the engine.

__init__(config: InferenceModelConfig) None[source]
async chat_async(messages: List[Dict], **kwargs) List[Experience][source]

Chat with the model with a list of messages in async.

Parameters:
  • messages (List[dict]) – The input history messages.

  • kwargs (dict) – A dictionary of sampling parameters.

Returns:

A list of experiences.

async generate_async(prompt: str, **kwargs) List[Experience][source]

Generate a response from the provided prompt in async.

Parameters:
  • prompt (str) – The input prompt.

  • kwargs (dict) – A dictionary of sampling parameters.

Returns:

A list of experiences.

async logprobs_async(token_ids: List[int]) Tensor[source]

Calculate the logprobs of the given tokens in async.

async convert_messages_to_experience_async(messages: List[dict]) Experience[source]

Convert a list of messages into an experience.

shutdown()[source]

Shut down the vLLM v1 engine. This kills the child processes forked by the vLLM engine. If this method is not called, those child processes are orphaned: they are not killed when the parent process exits, and Ray can no longer track them.

async sync_model(model_version: int, update_weight_args_list: List[Tuple] | None = None) bool[source]

Sync model weights to vLLM.

async init_process_group(master_address: str, master_port: int, rank_offset: int, world_size: int, group_name: str, explorer_name: str, backend: str = 'nccl', timeout: int = 1200, update_with_checkpoint: bool = True, state_dict_meta: dict | None = None)[source]
async run_api_server()[source]

Run the OpenAI API server in a Ray actor.

Note

Do not use ray.get() on this method. This method will run forever until the server is shut down.
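The reason for this note can be shown with a plain asyncio analogy. The sketch below does not use Ray or trinity: FakeServerActor and the address it reports are made up for illustration. It schedules the never-returning server coroutine in the background and polls a readiness method instead of waiting for the server call to finish. With Ray, the analogous pattern is to call the method with .remote() and not pass the returned reference to ray.get().

```python
import asyncio
from typing import Optional


class FakeServerActor:
    """Hypothetical stand-in for the rollout model actor; not trinity code."""

    def __init__(self) -> None:
        self.api_url: Optional[str] = None

    async def run_api_server(self) -> None:
        # Simulate startup, then serve until cancelled.
        await asyncio.sleep(0.01)
        self.api_url = "http://127.0.0.1:8000/v1"  # made-up address
        await asyncio.Event().wait()  # never returns on its own

    async def api_server_ready(self) -> Optional[str]:
        return self.api_url


async def main() -> Optional[str]:
    actor = FakeServerActor()
    # Schedule the server without awaiting it; awaiting would block forever.
    server_task = asyncio.create_task(actor.run_api_server())
    url = None
    while url is None:
        await asyncio.sleep(0.01)
        url = await actor.api_server_ready()
    server_task.cancel()  # stand-in for shutting the server down
    try:
        await server_task
    except asyncio.CancelledError:
        pass
    return url


ready_url = asyncio.run(main())
```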

async has_api_server() bool[source]
async api_server_ready() Tuple[str | None, str | None][source]

Check if the OpenAI API server is ready.

Returns:

A tuple (api_url, model_path): api_url (str) is the URL of the OpenAI API server, and model_path (str) is the path of the model.

Return type:

Tuple[str | None, str | None]

async reset_prefix_cache() None[source]
get_model_version() int[source]

Get the checkpoint version.

async sleep(level: int = 1) None[source]
async wake_up() None[source]

trinity.common.models.vllm_model module

vLLM related modules.

Modified from OpenRLHF openrlhf/trainer/ray/vllm_engine.py

class trinity.common.models.vllm_model.vLLMRolloutModel(config: InferenceModelConfig)[source]

Bases: InferenceModel

Actor for vLLM.

__init__(config: InferenceModelConfig)[source]
init_process_group(master_address: str, master_port: int, rank_offset: int, world_size: int, group_name: str, explorer_name: str, backend: str = 'nccl', timeout: int = 1200, update_with_checkpoint: bool = True, state_dict_meta: dict | None = None)[source]
reset_prefix_cache()[source]
sleep(level=1)[source]
wake_up()[source]
generate(prompts: List[str], **kwargs) List[source]

Generate a batch of responses from a batch of prompts.

Note

This method will not apply chat template. You need to apply chat template before calling this method.

Parameters:
  • prompts (List[str]) – A list of prompts.

  • kwargs (dict) – A dictionary of sampling parameters.

Returns:

A list of experiences.

Return type:

List[Experience]

Example

>>> # config.algorithm.repeat_times == 2 or kwargs["n"] == 2
>>>
>>> prompts = [
...     "Hello, world!",
...     "How are you?"
... ]
>>> experiences = model.generate(prompts)
>>> print(experiences)
[
    Experience(tokens=tensor()...),  # first sequence for prompts[0]
    Experience(tokens=tensor()...),  # second sequence for prompts[0]
    Experience(tokens=tensor()...),  # first sequence for prompts[1]
    Experience(tokens=tensor()...)   # second sequence for prompts[1]
]
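The flat ordering shown in the example above (all sequences for prompts[0], then all sequences for prompts[1]) can be regrouped per prompt with a small helper. This is a sketch under that ordering assumption: group_by_prompt is a hypothetical function, not part of the library, and strings stand in for Experience objects.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def group_by_prompt(flat: Sequence[T], n: int) -> List[List[T]]:
    """Split a flat, prompt-major list into one group of n items per prompt."""
    if n <= 0 or len(flat) % n != 0:
        raise ValueError("n must be positive and divide len(flat)")
    return [list(flat[i : i + n]) for i in range(0, len(flat), n)]


# Placeholder strings in the same order as the example output above.
flat = ["p0-seq0", "p0-seq1", "p1-seq0", "p1-seq1"]
groups = group_by_prompt(flat, n=2)
```

Here groups[i] holds the n sequences generated for prompts[i].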
chat(messages: List[dict], **kwargs) List[Experience][source]

Chat with the model with a list of messages.

Parameters:

messages (List[dict]) – A list of messages.

Example

>>> [
...   {"role": "system", "content": "You are a helpful assistant."},
...   {"role": "user", "content": "Hello, world!"},
... ]
Returns:

A list of experiences containing the response text.

Return type:

List[Experience]

logprobs(token_ids: List[int]) Tensor[source]

Generate logprobs for a list of tokens.

convert_messages_to_experience(messages: List[dict]) Experience[source]

Convert a list of messages into an experience.

has_api_server() bool[source]
sync_model(model_version: int, update_weight_args_list: List[Tuple] | None = None) bool[source]

Sync model weights to vLLM.

get_model_version() int[source]

Get the checkpoint version.

trinity.common.models.vllm_worker module

Custom vLLM Worker.

class trinity.common.models.vllm_worker.WorkerExtension[source]

Bases: object

init_process_group(master_address: str, master_port: int, rank_offset: int, world_size: int, group_name: str, backend: str = 'nccl', timeout: int = 1200, update_with_checkpoint: bool = True, state_dict_meta: list | None = None, explorer_name: str | None = None, namespace: str | None = None)[source]

Initialize the torch process group for model weight updates.

set_state_dict_meta(state_dict_meta)[source]
update_weight()[source]

Broadcast weights to all vLLM workers from source rank 0 (the actor model).

Module contents

trinity.common.models.create_inference_models(config: Config) Tuple[List[InferenceModel], List[List[InferenceModel]]][source]

Create engine_num rollout models.

Each model has tensor_parallel_size workers.