trinity.common.models.model module#

Base Model Class

class trinity.common.models.model.InferenceModel[源代码]#

基类:ABC

A high-performance model for rollout inference.

async generate(prompt: str, **kwargs) → Sequence[Experience][源代码]#

Generate responses from a prompt asynchronously.

async chat(messages: List[dict], **kwargs) → Sequence[Experience][源代码]#

Generate experiences asynchronously from a list of chat-history messages.

async logprobs(token_ids: List[int], **kwargs) → Tensor[源代码]#

Generate logprobs for a list of tokens asynchronously.

async convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) → Experience[源代码]#

Convert a list of messages into an experience asynchronously.

async prepare() → None[源代码]#

Prepare the model before inference.

abstractmethod async sync_model(model_version: int) → int[源代码]#

Sync the model with the latest model_version.

abstractmethod get_model_version() → int[源代码]#

Get the checkpoint version.

get_available_address() → Tuple[str, int][源代码]#

Get the address of the actor.

get_api_server_url() → str | None[源代码]#

Get the API server URL if available.

get_model_path() → str | None[源代码]#

Get the model path.

class trinity.common.models.model.ModelWrapper(model: InferenceModel, engine_type: str = 'vllm', enable_lora: bool = False, enable_history: bool = False, enable_thinking: bool | None = None)[源代码]#

基类:object

A wrapper for the InferenceModel Ray actor.

__init__(model: InferenceModel, engine_type: str = 'vllm', enable_lora: bool = False, enable_history: bool = False, enable_thinking: bool | None = None)[源代码]#

Initialize the ModelWrapper.

参数:
  • model (InferenceModel) -- The inference model Ray actor.

  • engine_type (str) -- The type of the model engine. Defaults to "vllm".

  • enable_lora (bool) -- Whether to enable LoRA. Defaults to False.

  • enable_history (bool) -- Whether to enable history recording. Defaults to False.

  • enable_thinking (Optional[bool]) -- Whether to enable thinking mode. Defaults to None. Only used for Qwen3-series models.

async prepare() → None[源代码]#

Prepare the model wrapper.

generate(*args, **kwargs)[源代码]#
async generate_async(*args, **kwargs)[源代码]#
generate_mm(*args, **kwargs)[源代码]#
async generate_mm_async(*args, **kwargs)[源代码]#
chat(*args, **kwargs)[源代码]#
async chat_async(*args, **kwargs)[源代码]#
chat_mm(*args, **kwargs)[源代码]#
async chat_mm_async(*args, **kwargs)[源代码]#
logprobs(tokens: List[int], temperature: float | None = None) → Tensor[源代码]#

Calculate the logprobs of the given tokens.

async logprobs_async(tokens: List[int], temperature: float | None = None) → Tensor[源代码]#

Calculate the logprobs of the given tokens asynchronously.

convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) → Experience[源代码]#

Convert a list of messages into an experience.

async convert_messages_to_experience_async(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) → Experience[源代码]#

Convert a list of messages into an experience asynchronously.

property model_version: int#

Get the version of the model.

property model_version_async: int#

Get the version of the model.

property model_path: str#

Get the model path.

property model_path_async: str#

Get the model path.

get_lora_request() → Any[源代码]#
async get_lora_request_async() → Any[源代码]#
async get_message_token_len(messages: List[dict]) → int[源代码]#
get_openai_client() → OpenAI[源代码]#

Get the openai client.

返回:

The OpenAI client, with an added model_path attribute that refers to the model path.

返回类型:

openai.OpenAI

get_openai_async_client() → AsyncOpenAI[源代码]#

Get the async openai client.

返回:

The async OpenAI client, with an added model_path attribute that refers to the model path.

返回类型:

openai.AsyncOpenAI

async get_current_load() → int[源代码]#

Get the current load metrics of the model.

async sync_model_weights(model_version: int) → None[源代码]#

Sync the model weights.

extract_experience_from_history(clear_history: bool = True) → List[Experience][源代码]#

Extract experiences from the history.

async set_workflow_state(state: Dict) → None[源代码]#

Set the state of the workflow that is using the model.

async clean_workflow_state() → None[源代码]#

Clean the state of the workflow that is using the model.

async get_workflow_state() → Dict[源代码]#

Get the state of the workflow that is using the model.

trinity.common.models.model.convert_api_output_to_experience(output) → List[Experience][源代码]#

Convert the API output to a list of experiences.

trinity.common.models.model.extract_logprobs(choice) → Tensor[源代码]#

Extract logprobs from a list of logprob dictionaries.