trinity.common.models
Submodules
trinity.common.models.model module
Base Model Class
- class trinity.common.models.model.InferenceModel[source]
Bases:
ABC
A model for high-performance rollout inference.
- generate(prompts: List[str], **kwargs) List[Experience] [source]
Generate a batch of responses from a batch of prompts.
- chat(messages: List[dict], **kwargs) List[Experience] [source]
Generate experiences from a list of history chat messages.
- convert_messages_to_experience(messages: List[dict]) Experience [source]
Convert a list of messages into an experience.
- async generate_async(prompt: str, **kwargs) List[Experience] [source]
Generate responses from a single prompt asynchronously.
- async chat_async(messages: List[dict], **kwargs) List[Experience] [source]
Generate experiences from a list of history chat messages asynchronously.
- async logprobs_async(tokens: List[int]) Tensor [source]
Generate logprobs for a list of tokens asynchronously.
- async convert_messages_to_experience_async(messages: List[dict]) Experience [source]
Convert a list of messages into an experience asynchronously.
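A minimal usage sketch for the interface above, assuming model is an already-constructed concrete implementation (for example, the vLLMRolloutModel documented below); the prompts and messages are illustrative.

# Sketch only: `model` is assumed to be a concrete InferenceModel subclass.
prompts = ["Hello, world!", "How are you?"]
experiences = model.generate(prompts)  # List[Experience], one or more per prompt

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello, world!"},
]
chat_experiences = model.chat(messages)  # List[Experience] continuing the conversation

# Convert a finished conversation into a single Experience for training.
experience = model.convert_messages_to_experience(messages + [
    {"role": "assistant", "content": "Hi! How can I help you?"},
])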
- class trinity.common.models.model.ModelWrapper(model: Any, model_type: str = 'vllm')[source]
Bases:
object
A wrapper for the InferenceModel Ray Actor.
- generate(prompts: List[str], **kwargs) List[Experience] [source]
- chat(messages: List[dict], **kwargs) List[Experience] [source]
- convert_messages_to_experience(messages: List[dict]) Experience [source]
Convert a list of messages into an experience.
- property model_version: int
Get the version of the model.
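A minimal sketch of wrapping an InferenceModel Ray actor, assuming model_actor is an existing actor handle; the handle name is illustrative.

from trinity.common.models.model import ModelWrapper

# `model_actor` is assumed to be a Ray actor handle to an InferenceModel.
wrapper = ModelWrapper(model_actor, model_type="vllm")
experiences = wrapper.chat([
    {"role": "user", "content": "Hello, world!"},
])
print(wrapper.model_version)  # version of the weights currently served by the actor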
trinity.common.models.openai_api module
OpenAI API server related tools.
Modified from vllm/entrypoints/openai/api_server.py
trinity.common.models.utils module
- trinity.common.models.utils.tokenize_and_mask_messages_hf(tokenizer: Any, messages: List[dict], chat_template: str | None = None) Tuple[Tensor, Tensor] [source]
Calculate the assistant token mask with chat_template.
- Parameters:
tokenizer (Any) – The tokenizer.
chat_template (str) – The chat template containing the {% generation %} tag.
messages (List[dict]) – Messages with role and content fields.
- Returns:
The token_ids (sequence_length) and assistant_masks (sequence_length).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
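A usage sketch, assuming a Hugging Face tokenizer whose chat template contains the {% generation %} tag; the model name is illustrative, and the chosen model's default template may need to be replaced with a {% generation %}-aware one.

from transformers import AutoTokenizer
from trinity.common.models.utils import tokenize_and_mask_messages_hf

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # illustrative
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello, world!"},
    {"role": "assistant", "content": "Hi! How can I help you?"},
]
# chat_template must contain {% generation %}; pass a custom template if the
# tokenizer's default one does not.
token_ids, assistant_masks = tokenize_and_mask_messages_hf(
    tokenizer, messages, chat_template=tokenizer.chat_template
)
# token_ids: shape (sequence_length,); assistant_masks: 1 for tokens produced in
# assistant turns, 0 elsewhere.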
- trinity.common.models.utils.tokenize_and_mask_messages_default(tokenizer: Any, messages: List[dict], chat_template: str | None = None) Tuple[Tensor, Tensor] [source]
Calculate the assistant token mask.
- Parameters:
tokenizer (Any) – The tokenizer.
chat_template (str) – The chat template containing the {% generation %} tag.
messages (List[dict]) – Messages with role and content fields.
- Returns:
The token_ids (sequence_length) and assistant_masks (sequence_length).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
Note
This method is based on the assumption that as the number of chat rounds increases, the tokens of the previous round are exactly the prefix tokens of the next round. If the assumption is not met, the function may produce incorrect results. Please check the chat template before using this method.
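A quick way to check the prefix assumption described in the note before relying on this method; the tokenizer name is illustrative.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # illustrative
short = [{"role": "user", "content": "Hi"}]
long = short + [
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "How are you?"},
]
short_ids = tokenizer.apply_chat_template(short, tokenize=True)
long_ids = tokenizer.apply_chat_template(long, tokenize=True)
# If this fails, the chat template breaks the prefix assumption; prefer
# tokenize_and_mask_messages_hf with a {% generation %}-aware template instead.
assert long_ids[: len(short_ids)] == short_ids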
- trinity.common.models.utils.get_checkpoint_dir_with_step_num(checkpoint_root_path: str, trainer_type: str = 'verl', step_num: int | None = None) Tuple[str, int] [source]
Get the checkpoint directory from a root checkpoint directory.
- Parameters:
checkpoint_root_path (str) – The root checkpoint directory.
trainer_type (str) – The trainer type. Only supports “verl” for now.
step_num (Optional[int], optional) – The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None.
- Returns:
The checkpoint directory and the step number of the checkpoint.
- Return type:
Tuple[str, int]
- trinity.common.models.utils.load_state_dict(checkpoint_dir: str, trainer_type: str = 'verl') dict [source]
Load state dict from a checkpoint dir.
- Parameters:
checkpoint_dir (str) – The checkpoint directory.
trainer_type (str) – The trainer type. Only supports “verl” for now.
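A sketch of the typical flow combining the two helpers above: resolve the latest (or a specific) checkpoint directory, then load its weights. The root path is illustrative.

from trinity.common.models.utils import (
    get_checkpoint_dir_with_step_num,
    load_state_dict,
)

ckpt_dir, step = get_checkpoint_dir_with_step_num(
    "/path/to/checkpoints",   # root directory written by the trainer (illustrative)
    trainer_type="verl",
    step_num=None,            # None -> load the latest checkpoint
)
state_dict = load_state_dict(ckpt_dir, trainer_type="verl")
print(f"loaded weights from step {step}")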
- trinity.common.models.utils.merge_by_placement(tensors: List[Tensor], placement: Placement)[source]
- trinity.common.models.utils.get_verl_checkpoint_info(checkpoint_path: str, step_num: int | None = None) Tuple[str, int] [source]
Get the checkpoint directory from a Verl root checkpoint directory.
- Parameters:
checkpoint_path (str) – The root checkpoint directory.
step_num (Optional[int], optional) – The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None.
- Returns:
The checkpoint directory and the step number of the checkpoint.
- Return type:
Tuple[str, int]
trinity.common.models.vllm_async_model module
vLLM AsyncEngine wrapper.
Modified from Ray python/ray/llm/_internal/batch/stages/vllm_engine_stage.py
- class trinity.common.models.vllm_async_model.vLLMAysncRolloutModel(config: InferenceModelConfig)[source]
Bases:
InferenceModel
Wrapper around the vLLM engine to handle async requests.
- Parameters:
config (InferenceModelConfig) – The model configuration.
kwargs (dict) – The keyword arguments for the engine.
- __init__(config: InferenceModelConfig) None [source]
- async chat_async(messages: List[Dict], **kwargs) List[Experience] [source]
Chat with the model using a list of messages asynchronously.
- Parameters:
messages (List[dict]) – The input history messages.
kwargs (dict) – A dictionary of sampling parameters.
- Returns:
A list of experiences.
- async generate_async(prompt: str, **kwargs) List[Experience] [source]
Generate a response from the provided prompt asynchronously.
- Parameters:
prompt (str) – The input prompt.
kwargs (dict) – A dictionary of sampling parameters.
- Returns:
A list of experiences.
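A sketch of issuing several chat requests concurrently, assuming model is a vLLMAysncRolloutModel used from within an async context; when the model runs as a Ray actor, the same calls are made through chat_async.remote(...) instead.

import asyncio

async def rollout(model, conversations):
    # kwargs are forwarded as sampling parameters (e.g. temperature).
    results = await asyncio.gather(
        *(model.chat_async(msgs, temperature=0.7) for msgs in conversations)
    )
    # Each element is a List[Experience] for the corresponding conversation.
    return results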
- async logprobs_async(token_ids: List[int]) Tensor [source]
Calculate the logprobs of the given tokens asynchronously.
- async convert_messages_to_experience_async(messages: List[dict]) Experience [source]
Convert a list of messages into an experience.
- shutdown()[source]
Shut down the vLLM v1 engine. This kills the child processes forked by the vLLM engine. If not called, those child processes are orphaned: they are not killed when the parent process exits and can no longer be tracked by Ray.
- async sync_model(model_version: int, update_weight_args_list: List[Tuple] | None = None) bool [source]
Sync model weights to vLLM.
- async init_process_group(master_address: str, master_port: int, rank_offset: int, world_size: int, group_name: str, explorer_name: str, backend: str = 'nccl', timeout: int = 1200, update_with_checkpoint: bool = True, state_dict_meta: dict | None = None)[source]
- async run_api_server()[source]
Run the OpenAI API server in a Ray actor.
Note
Do not use ray.get() on this method. This method will run forever until the server is shut down.
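A sketch of starting the server from a driver process, assuming model_actor is a Ray actor handle for this class; per the note above, the call is not wrapped in ray.get() because it blocks until the server shuts down.

# Fire-and-forget remote call; keep the reference to track or cancel the task.
server_ref = model_actor.run_api_server.remote()
# ... continue with other work while the OpenAI-compatible server runs in the actor.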
trinity.common.models.vllm_model module
vLLM related modules.
Modified from OpenRLHF openrlhf/trainer/ray/vllm_engine.py
- class trinity.common.models.vllm_model.vLLMRolloutModel(config: InferenceModelConfig)[source]
Bases:
InferenceModel
Actor for vLLM.
- __init__(config: InferenceModelConfig)[source]
- init_process_group(master_address: str, master_port: int, rank_offset: int, world_size: int, group_name: str, explorer_name: str, backend: str = 'nccl', timeout: int = 1200, update_with_checkpoint: bool = True, state_dict_meta: dict | None = None)[source]
- generate(prompts: List[str], **kwargs) List [source]
Generate a batch of responses from a batch of prompts.
Note
This method does not apply the chat template; apply it to your prompts before calling this method.
- Parameters:
prompts (List[str]) – A list of prompts.
kwargs (dict) – A dictionary of sampling parameters.
- Returns:
A list of experiences.
- Return type:
List[Experience]
Example
>>> # config.algorithm.repeat_times == 2 or kwargs["n"] == 2
>>>
>>> prompts = [
>>>     "Hello, world!",
>>>     "How are you?"
>>> ]
>>> experiences = model.generate(prompts)
>>> print(experiences)
[
    Experience(tokens=tensor()...),  # first sequence for prompts[0]
    Experience(tokens=tensor()...),  # second sequence for prompts[0]
    Experience(tokens=tensor()...),  # first sequence for prompts[1]
    Experience(tokens=tensor()...)   # second sequence for prompts[1]
]
- chat(messages: List[dict], **kwargs) List[Experience] [source]
Chat with the model with a list of messages.
- Parameters:
messages (List[dict]) – A list of messages.
Example
>>> [
>>>     {"role": "system", "content": "You are a helpful assistant."},
>>>     {"role": "user", "content": "Hello, world!"},
>>> ]
- Returns:
A list of experiences containing the response text.
- Return type:
List[Experience]
- convert_messages_to_experience(messages: List[dict]) Experience [source]
Convert a list of messages into an experience.
trinity.common.models.vllm_worker module
Custom vLLM Worker.
- class trinity.common.models.vllm_worker.WorkerExtension[source]
Bases:
object
- init_process_group(master_address: str, master_port: int, rank_offset: int, world_size: int, group_name: str, backend: str = 'nccl', timeout: int = 1200, update_with_checkpoint: bool = True, state_dict_meta: list | None = None, explorer_name: str | None = None, namespace: str | None = None)[source]
Initialize the torch process group for model weight updates.
Module contents
- trinity.common.models.create_inference_models(config: Config) Tuple[List[InferenceModel], List[List[InferenceModel]]] [source]
Create engine_num rollout models.
Each model has tensor_parallel_size workers.
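A minimal sketch of this module-level entry point, assuming config is a fully populated Config; per the return annotation, the first element holds the rollout models and the second holds per-engine lists of additional models.

from trinity.common.models import create_inference_models

# `config` is assumed to be a prepared trinity Config instance.
rollout_models, auxiliary_models = create_inference_models(config)
# len(rollout_models) == engine_num; each model runs with tensor_parallel_size workers.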