trinity.common.models
Submodules
trinity.common.models.model module
Base Model Class
- class trinity.common.models.model.InferenceModel[source]
Bases:
ABC
A model for high-performance rollout inference.
- async generate(prompt: str, **kwargs) Sequence[Experience] [source]
Asynchronously generate responses from a prompt.
- async chat(messages: List[dict], **kwargs) Sequence[Experience] [source]
Asynchronously generate experiences from a list of history chat messages.
- async convert_messages_to_experience(messages: List[dict]) Experience [source]
Asynchronously convert a list of messages into an experience.
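A minimal usage sketch of the async interface, assuming model is a handle to a concrete InferenceModel implementation such as the vLLMRolloutModel documented below (the prompt text and message contents are illustrative):
async def rollout(model):
    # Single-prompt completion; extra keyword arguments are forwarded as
    # backend-specific sampling parameters.
    experiences = await model.generate("Write a haiku about reinforcement learning.")

    # Multi-turn chat from a message history.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
    ]
    chat_experiences = await model.chat(messages)

    # Turn the finished conversation into a single training experience.
    experience = await model.convert_messages_to_experience(messages)
    return experiences, chat_experiences, experience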
- class trinity.common.models.model.ModelWrapper(model: Any, model_type: str = 'vllm', enable_history: bool = False)[source]
Bases:
object
A wrapper for the InferenceModel Ray actor.
- async logprobs_async(tokens: List[int]) Tensor [source]
Asynchronously calculate the logprobs of the given tokens.
- convert_messages_to_experience(messages: List[dict]) Experience [source]
Convert a list of messages into an experience.
- async convert_messages_to_experience_async(messages: List[dict]) Experience [source]
Asynchronously convert a list of messages into an experience.
- property model_version: int
Get the version of the model.
- get_openai_client() OpenAI [source]
Get the OpenAI client.
- Returns:
The OpenAI client. A model_path attribute is added to the client, which refers to the model path.
- Return type:
openai.OpenAI
- extract_experience_from_history(clear_history: bool = True) List[Experience] [source]
Extract experiences from the history.
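A usage sketch for the OpenAI-compatible path, assuming wrapper is a ModelWrapper constructed with enable_history=True and backed by an engine that is already serving the OpenAI API (the prompt is illustrative):
def query_and_collect(wrapper):
    # The wrapper attaches model_path to the returned client; use it as the
    # model name in standard OpenAI calls.
    client = wrapper.get_openai_client()
    response = client.chat.completions.create(
        model=client.model_path,
        messages=[{"role": "user", "content": "Explain policy gradients in one sentence."}],
    )
    print(response.choices[0].message.content)

    # Assuming history recording covers calls made through this client, the
    # recorded calls can be pulled back as Experience objects for training.
    return wrapper.extract_experience_from_history(clear_history=True)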
- trinity.common.models.model.convert_api_output_to_experience(output) List[Experience] [source]
Convert the API output to a list of experiences.
trinity.common.models.utils module
- trinity.common.models.utils.tokenize_and_mask_messages_hf(tokenizer: Any, messages: List[dict], chat_template: str | None = None) Tuple[Tensor, Tensor, int] [source]
Calculate the assistant token mask with chat_template.
- Parameters:
tokenizer (Any) – The tokenizer.
messages (List[dict]) – Messages with role and content fields.
chat_template (Optional[str]) – The chat template containing the {% generation %} keyword.
- Returns:
The token_ids (sequence_length), the assistant_masks (sequence_length), and the prompt length.
- Return type:
Tuple[torch.Tensor, torch.Tensor, int]
- trinity.common.models.utils.tokenize_and_mask_messages_default(tokenizer: Any, messages: List[dict], chat_template: str | None = None) Tuple[Tensor, Tensor, int] [source]
Calculate the assistant token mask.
- Parameters:
tokenizer (Any) – The tokenizer.
messages (List[dict]) – Messages with role and content fields.
chat_template (Optional[str]) – The chat template containing the {% generation %} keyword.
- Returns:
The token_ids (sequence_length), the assistant_masks (sequence_length), and the prompt length.
- Return type:
Tuple[torch.Tensor, torch.Tensor, int]
Note
This method is based on the assumption that as the number of chat rounds increases, the tokens of the previous round are exactly the prefix tokens of the next round. If the assumption is not met, the function may produce incorrect results. Please check the chat template before using this method.
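A sketch of calling the masking helpers, assuming a Hugging Face tokenizer; tokenize_and_mask_messages_hf takes the same arguments but requires a chat template containing the {% generation %} keyword, while the default variant used here relies on the prefix assumption described in the note above (the model name is illustrative):
from transformers import AutoTokenizer

from trinity.common.models.utils import tokenize_and_mask_messages_default

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4."},
]

# token_ids:       (sequence_length,) tensor for the whole conversation
# assistant_masks: (sequence_length,) tensor marking assistant tokens
# prompt_length:   the prompt length reported by the helper
token_ids, assistant_masks, prompt_length = tokenize_and_mask_messages_default(
    tokenizer, messages
)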
- trinity.common.models.utils.get_checkpoint_dir_with_step_num(checkpoint_root_path: str, trainer_type: str = 'verl', step_num: int | None = None) Tuple[str, int] [source]
Get the checkpoint directory from a root checkpoint directory.
- Parameters:
checkpoint_root_path (str) – The root checkpoint directory.
trainer_type (str) – The trainer type. Only “verl” is supported for now.
step_num (Optional[int], optional) – The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None.
- Returns:
The checkpoint directory and the step number of the checkpoint.
- Return type:
Tuple[str, int]
- trinity.common.models.utils.load_state_dict(checkpoint_dir: str, trainer_type: str = 'verl') dict [source]
Load state dict from a checkpoint dir.
- Parameters:
checkpoint_dir (str) – The checkpoint directory.
trainer_type (str) – The trainer type. Only “verl” is supported for now.
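A sketch of resolving and loading a checkpoint, assuming a verl-style checkpoint root (the path is illustrative):
from trinity.common.models.utils import (
    get_checkpoint_dir_with_step_num,
    load_state_dict,
)

# Resolve the latest checkpoint under the root; pass step_num=... to pin a step.
checkpoint_dir, step_num = get_checkpoint_dir_with_step_num(
    "/path/to/checkpoints", trainer_type="verl"
)

# Load the model weights stored in that checkpoint directory.
state_dict = load_state_dict(checkpoint_dir, trainer_type="verl")
print(f"loaded step {step_num}: {len(state_dict)} entries")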
- trinity.common.models.utils.merge_by_placement(tensors: List[Tensor], placement: Placement)[source]
- trinity.common.models.utils.get_verl_checkpoint_info(checkpoint_path: str, step_num: int | None = None) Tuple[str, int] [source]
Get the checkpoint directory from a Verl root checkpoint directory.
- Parameters:
checkpoint_path (str) – The root checkpoint directory.
step_num (Optional[int], optional) – The step number. If specified, load the checkpoint with the specified step number. If None, load the latest checkpoint. Defaults to None.
- Returns:
The checkpoint directory and the step number of the checkpoint.
- Return type:
Tuple[str, int]
trinity.common.models.vllm_model module
A wrapper around the vllm.AsyncEngine to handle async requests.
- class trinity.common.models.vllm_model.vLLMRolloutModel(config: InferenceModelConfig)[source]
Bases:
InferenceModel
Wrapper around the vLLM engine to handle async requests.
- Parameters:
config (InferenceModelConfig) – The model config.
- __init__(config: InferenceModelConfig) None [source]
- async chat(messages: List[Dict], **kwargs) Sequence[Experience] [source]
Asynchronously chat with the model given a list of history messages.
- Parameters:
messages (List[dict]) – The input history messages.
kwargs (dict) – A dictionary of sampling parameters.
- Returns:
A list of experiences.
- async generate(prompt: str, **kwargs) Sequence[Experience] [source]
Asynchronously generate a response from the provided prompt.
- Parameters:
prompt (str) – The input prompt.
kwargs (dict) – A dictionary of sampling parameters.
- Returns:
A list of experiences.
- async logprobs(token_ids: List[int]) Tensor [source]
Asynchronously calculate the logprobs of the given tokens. Slice the result carefully to align with the actual response length; see the slicing sketch at the end of this class.
- Parameters:
token_ids (List[int]) – The input token ids (seq_length).
- Returns:
A tensor of logprobs (seq_length - 1).
- async convert_messages_to_experience(messages: List[dict]) Experience [source]
Convert a list of messages into an experience.
- shutdown()[source]
Shut down the vLLM v1 engine. This kills child processes forked by the vLLM engine. If not called, the child processes are orphaned when the parent process exits and can no longer be tracked by Ray.
- async init_process_group(master_address: str, master_port: int, rank_offset: int, world_size: int, group_name: str, explorer_name: str, backend: str = 'nccl', timeout: int = 1200, state_dict_meta: dict | None = None)[source]
- async run_api_server()[source]
Run the OpenAI API server in a Ray actor.
Note
Do not use ray.get() on this method. This method will run forever until the server is shut down.
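Two small usage sketches for this class. First, the slicing that the logprobs docstring asks for, assuming token_ids is the full prompt-plus-response token sequence and response_length is the number of response tokens (both names are illustrative):
async def response_logprobs(model, token_ids, response_length):
    # logprobs() returns one value per predicted token, so its length is
    # len(token_ids) - 1 and position i scores token_ids[i + 1].
    logprobs = await model.logprobs(token_ids)
    # Keep only the entries that score the response tokens.
    return logprobs[-response_length:]

Second, launching the OpenAI API server, assuming model_actor is a vLLMRolloutModel running as a Ray actor (the variable name is illustrative). Per the note above, never wait on the returned reference with ray.get():
# Fire and forget: the server loop runs inside the actor until shutdown() is called.
server_ref = model_actor.run_api_server.remote()
# Do NOT call ray.get(server_ref); the server never returns, so it would block forever.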
trinity.common.models.vllm_worker module
Custom vLLM Worker.
- class trinity.common.models.vllm_worker.WorkerExtension[source]
Bases:
object
- init_process_group(master_address: str, master_port: int, rank_offset: int, world_size: int, group_name: str, backend: str = 'nccl', timeout: int = 1200, state_dict_meta: list | None = None, explorer_name: str | None = None, namespace: str | None = None)[source]
Initialize the torch process group for model weight updates.
Module contents
- trinity.common.models.create_inference_models(config: Config) Tuple[List[InferenceModel], List[List[InferenceModel]]] [source]
Create engine_num rollout models.
Each model has tensor_parallel_size workers.
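A sketch of consuming the return value, assuming config is the experiment's populated Config; the names rollout_models and auxiliary_model_groups are illustrative, and only the tuple's types are taken from the signature:
from trinity.common.models import create_inference_models

def build_models(config):
    # First element: one rollout model per configured engine, each of which
    # internally uses tensor_parallel_size workers.
    # Second element: nested lists of InferenceModel handles
    # (List[List[InferenceModel]]).
    rollout_models, auxiliary_model_groups = create_inference_models(config)
    return rollout_models, auxiliary_model_groups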