trinity.common.models.model module#

Base Model Class

class trinity.common.models.model.InferenceModel[source]#

Bases: ABC

A high-performance model for rollout inference.

async generate(prompt: str, **kwargs) → Sequence[Experience][source]#

Generate responses from a prompt asynchronously.

async chat(messages: List[dict], **kwargs) → Sequence[Experience][source]#

Generate experiences asynchronously from a list of chat history messages.

async logprobs(tokens: List[int]) → Tensor[source]#

Generate logprobs for a list of tokens asynchronously.

async convert_messages_to_experience(messages: List[dict]) → Experience[source]#

Convert a list of messages into an experience asynchronously.

abstract get_model_version() → int[source]#

Get the checkpoint version.

get_available_address() → Tuple[str, int][source]#

Get the address of the actor.

class trinity.common.models.model.ModelWrapper(model: Any, model_type: str = 'vllm', enable_history: bool = False)[source]#

Bases: object

A wrapper for the InferenceModel Ray actor.

__init__(model: Any, model_type: str = 'vllm', enable_history: bool = False)[source]#
generate(*args, **kwargs)[source]#
async generate_async(*args, **kwargs)[source]#
generate_mm(*args, **kwargs)[source]#
async generate_mm_async(*args, **kwargs)[source]#
chat(*args, **kwargs)[source]#
async chat_async(*args, **kwargs)[source]#
chat_mm(*args, **kwargs)[source]#
async chat_mm_async(*args, **kwargs)[source]#
logprobs(tokens: List[int]) → Tensor[source]#

Calculate the logprobs of the given tokens.

async logprobs_async(tokens: List[int]) → Tensor[source]#

Calculate the logprobs of the given tokens asynchronously.

convert_messages_to_experience(messages: List[dict]) → Experience[source]#

Convert a list of messages into an experience.

async convert_messages_to_experience_async(messages: List[dict]) → Experience[source]#

Convert a list of messages into an experience asynchronously.

property model_version: int#

Get the version of the model.

get_openai_client() → OpenAI[source]#

Get the OpenAI client.

Returns:

The OpenAI client, with an added model_path attribute that refers to the model path.

Return type:

openai.OpenAI

get_openai_async_client() → AsyncOpenAI[source]#

Get the async OpenAI client.

Returns:

The async OpenAI client, with an added model_path attribute that refers to the model path.

Return type:

openai.AsyncOpenAI

extract_experience_from_history(clear_history: bool = True) → List[Experience][source]#

Extract experiences from the history.

trinity.common.models.model.convert_api_output_to_experience(output) → List[Experience][source]#

Convert the API output to a list of experiences.

trinity.common.models.model.extract_logprobs(choice) → Tensor[source]#

Extract logprobs from a list of logprob dictionaries.