trinity.common.models.model module

Base Model Class

class trinity.common.models.model.InferenceModel

Bases: ABC

A model for high-performance rollout inference.

async generate(prompt: str, **kwargs) -> Sequence[Experience]

Asynchronously generate responses from a prompt.

async chat(messages: List[dict], **kwargs) -> Sequence[Experience]

Asynchronously generate experiences from a list of chat history messages.

async logprobs(tokens: List[int]) -> Tensor

Asynchronously generate logprobs for a list of tokens.

async convert_messages_to_experience(messages: List[dict]) -> Experience

Asynchronously convert a list of messages into an experience.

abstract get_model_version() -> int

Get the checkpoint version.

get_available_address() -> Tuple[str, int]

Get the address of the actor.
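
As an illustration of the interface above, the following sketch defines a stand-in abstract base class that mirrors some of the documented signatures, plus a toy subclass. InferenceModelSketch, EchoModel, the Experience dataclass, and the n keyword are hypothetical placeholders so that the block runs without Trinity installed; the real base class and Experience type may differ.

```python
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class Experience:
    """Hypothetical stand-in for Trinity's Experience type."""

    prompt: str
    response: str


class InferenceModelSketch(ABC):
    """Stand-in that mirrors part of the documented InferenceModel interface."""

    async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
        raise NotImplementedError

    async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
        raise NotImplementedError

    @abstractmethod
    def get_model_version(self) -> int:
        """Return the checkpoint version."""


class EchoModel(InferenceModelSketch):
    """Toy backend that echoes its input instead of running a real model."""

    def __init__(self, version: int = 0) -> None:
        self._version = version

    async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
        n = kwargs.get("n", 1)  # assumed sampling option, for illustration only
        return [Experience(prompt, f"echo: {prompt}") for _ in range(n)]

    async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
        # Use the last message as the prompt; a real backend would apply a chat template.
        return await self.generate(messages[-1]["content"], **kwargs)

    def get_model_version(self) -> int:
        return self._version


model = EchoModel(version=3)
experiences = asyncio.run(model.chat([{"role": "user", "content": "hi"}], n=2))
```

Because get_model_version is the only abstract method in the documented interface, it is the one method the sketch forces a subclass to implement; the async methods are overridden here only so the toy model returns something.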

class trinity.common.models.model.ModelWrapper(model: Any, model_type: str = 'vllm', enable_history: bool = False)

Bases: object

A wrapper for the InferenceModel Ray Actor.

__init__(model: Any, model_type: str = 'vllm', enable_history: bool = False)

generate(*args, **kwargs)

async generate_async(*args, **kwargs)

chat(*args, **kwargs)

async chat_async(*args, **kwargs)

logprobs(tokens: List[int]) -> Tensor

Calculate the logprobs of the given tokens.

async logprobs_async(tokens: List[int]) -> Tensor

Asynchronously calculate the logprobs of the given tokens.

convert_messages_to_experience(messages: List[dict]) -> Experience

Convert a list of messages into an experience.

async convert_messages_to_experience_async(messages: List[dict]) -> Experience

Asynchronously convert a list of messages into an experience.
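
The wrapper exposes each operation as a blocking method and an `_async` counterpart. The sketch below shows one way such a pair can be used: the blocking call for a single request, and the async call fanned out with asyncio.gather for several requests. WrapperSketch and chat_many are hypothetical stand-ins; in particular, implementing the blocking method with asyncio.run is an assumption made for this example and not a description of how ModelWrapper talks to its Ray actor.

```python
import asyncio
from typing import List


class WrapperSketch:
    """Hypothetical stand-in exposing a sync/async method pair like ModelWrapper."""

    async def chat_async(self, messages: List[dict], **kwargs) -> List[str]:
        await asyncio.sleep(0.01)  # simulate a remote inference call
        return [f"reply to: {messages[-1]['content']}"]

    def chat(self, messages: List[dict], **kwargs) -> List[str]:
        # Blocking variant; only valid when no event loop is already running.
        return asyncio.run(self.chat_async(messages, **kwargs))


async def chat_many(wrapper: WrapperSketch, questions: List[str]) -> List[List[str]]:
    # Issue all requests at once and wait for every result, preserving input order.
    tasks = [wrapper.chat_async([{"role": "user", "content": q}]) for q in questions]
    return await asyncio.gather(*tasks)


wrapper = WrapperSketch()
single = wrapper.chat([{"role": "user", "content": "a"}])
batch = asyncio.run(chat_many(wrapper, ["a", "b", "c"]))
```

Inside code that already runs in an event loop, the `_async` variants are the natural choice, since a blocking call would stall the loop for the duration of the request.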

property model_version: int

Get the version of the model.

get_openai_client() -> OpenAI

Get the OpenAI client.

Returns:

The OpenAI client, with a model_path attribute added that refers to the model path.

Return type:

openai.OpenAI
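
Since the returned client carries a model_path attribute, that attribute can be passed as the model argument of a chat completion request. The helper below is a minimal sketch of that idea. The ask function, the fake client, and the path "/path/to/model" are hypothetical; the fake client exists only so the block runs offline, and it imitates just the part of the response that the helper reads.

```python
from types import SimpleNamespace


def ask(client, question: str) -> str:
    """Send one user message, using the model path attached to the client."""
    response = client.chat.completions.create(
        model=client.model_path,
        messages=[{"role": "user", "content": question}],
    )
    return response.choices[0].message.content


# Fake client for offline illustration only; it echoes the model and question.
def _fake_create(model, messages):
    text = f"[{model}] {messages[-1]['content']}"
    message = SimpleNamespace(content=text)
    return SimpleNamespace(choices=[SimpleNamespace(message=message)])


fake_client = SimpleNamespace(
    model_path="/path/to/model",
    chat=SimpleNamespace(completions=SimpleNamespace(create=_fake_create)),
)
answer = ask(fake_client, "ping")
```

With a real wrapper, the same helper would presumably be called as ask(wrapper.get_openai_client(), "..."), assuming the underlying model serves an OpenAI-compatible endpoint.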

extract_experience_from_history(clear_history: bool = True) -> List[Experience]

Extract experiences from the history.
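
The exact recording behavior is not described here, so the following is only a sketch of what the clear_history flag suggests. It assumes that a wrapper created with enable_history=True records the experiences it produces, and that extraction returns them and optionally empties the record. HistorySketch is a hypothetical stand-in that uses plain strings in place of Experience objects.

```python
from typing import List


class HistorySketch:
    """Hypothetical stand-in for the history-related behavior of ModelWrapper."""

    def __init__(self, enable_history: bool = False) -> None:
        self.enable_history = enable_history
        self.history: List[str] = []

    def chat(self, messages: List[dict]) -> List[str]:
        experiences = [f"exp for: {messages[-1]['content']}"]
        if self.enable_history:
            self.history.extend(experiences)  # assumed: record each result
        return experiences

    def extract_experience_from_history(self, clear_history: bool = True) -> List[str]:
        experiences = list(self.history)  # copy, so clearing does not affect the result
        if clear_history:
            self.history.clear()
        return experiences


w = HistorySketch(enable_history=True)
w.chat([{"role": "user", "content": "first"}])
w.chat([{"role": "user", "content": "second"}])
peek = w.extract_experience_from_history(clear_history=False)  # history kept
drained = w.extract_experience_from_history()  # history emptied afterwards
leftover = w.extract_experience_from_history()
```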

trinity.common.models.model.convert_api_output_to_experience(output) -> List[Experience]

Convert the API output to a list of experiences.

trinity.common.models.model.extract_logprobs(choice) -> Tensor

Extract logprobs from a list of logprob dictionaries.
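
To make the idea concrete, here is a minimal sketch of pulling per-token log-probabilities out of a chat completion choice. It assumes the choice follows the OpenAI chat completions response shape, where choice.logprobs.content is a list of entries that each carry a logprob field. The function name extract_logprobs_sketch is hypothetical, and it returns plain Python floats rather than a Tensor so that the block has no torch dependency.

```python
from types import SimpleNamespace
from typing import List


def extract_logprobs_sketch(choice) -> List[float]:
    """Collect per-token log-probabilities from a chat-completion choice."""
    if choice.logprobs is None or choice.logprobs.content is None:
        return []  # logprobs were not requested or not returned
    return [token.logprob for token in choice.logprobs.content]


# Hand-built stand-in for one choice of an API response.
choice = SimpleNamespace(
    logprobs=SimpleNamespace(
        content=[
            SimpleNamespace(token="Hello", logprob=-0.25),
            SimpleNamespace(token="!", logprob=-1.5),
        ]
    )
)
values = extract_logprobs_sketch(choice)
```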