Source code for trinity.common.workflows.simple_mm_workflow

from typing import List, Optional

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards.reward_fn import RewardFn
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task


@WORKFLOWS.register_module("simple_mm_workflow")
class SimpleMMWorkflow(SimpleWorkflow):
    """A workflow for a simple single-round multimodal task."""
    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        self.reset(task)
        super().__init__(
            task=task,
            model=model,
            auxiliary_models=auxiliary_models,
        )
    def reset(self, task: Task):
        self.format_args = task.format_args
        self.system_prompt = """You are a helpful assistant that solves MATH problems. You should first thinks about the reasoning process in mind and then provides the user with the answer.
You should present your reasoning process using the format: <think>\n ...your reasoning process here... </think>\n first.
You should always include your final answer in \\boxed{} as closed-form results."""  # TODO: check
        self.reply_prefix = task.format_args.reply_prefix
        self.reward_fn_args = task.reward_fn_args
        self.raw_task = task.raw_task
        self.task_desc = task.task_desc
        assert task.raw_task is not None
        self.truth = task.raw_task[task.format_args.response_key] or task.truth
        reward_fn = task.reward_fn
        if isinstance(reward_fn, type) and issubclass(reward_fn, RewardFn):
            self.reward_fn: RewardFn = reward_fn(**self.reward_fn_args)
        else:
            raise ValueError("`reward_fn` must be a subclass of `RewardFn`")
        # Collect any image/video payloads from the raw task for multimodal chat.
        self.image_key = task.format_args.image_key
        self.video_key = task.format_args.video_key
        self.raw_mm_data = {}
        if self.image_key and task.raw_task.get(self.image_key) is not None:
            self.raw_mm_data["image"] = task.raw_task[self.image_key]
        if self.video_key and task.raw_task.get(self.video_key) is not None:
            self.raw_mm_data["video"] = task.raw_task[self.video_key]
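    # Illustrative example (not from the original source): assuming
    # format_args.response_key == "answer" and format_args.image_key == "images",
    # a raw_task like {"question": "...", "answer": "42", "images": [<PIL image>]}
    # leaves reset() with self.truth == "42" and
    # self.raw_mm_data == {"image": raw_task["images"]}.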
    def run(self) -> List[Experience]:
        messages = self.format_messages()
        # TODO: test generate_mm
        self.logger.debug("start chat")
        # Use the multimodal chat path when image/video data is attached to the task.
        if self.raw_mm_data:
            responses = self.model.chat_mm(messages, self.raw_mm_data, **self.rollout_args)
        else:
            responses = self.model.chat(messages, **self.rollout_args)
        # Score each sampled response and attach the reward and per-metric values.
        for i, response in enumerate(responses):
            reward_dict = self.reward_fn(  # type: ignore [misc]
                response=response.response_text,  # type: ignore [arg-type]
                truth=self.truth,
            )
            if response.metrics is None:
                response.metrics = {}
            response.metrics.update(reward_dict)
            reward = sum(reward_dict.values())
            response.reward = reward
            response.eid.run = i + self.run_id_base
        self.logger.debug(f"Generated {len(responses)} responses")
        return responses
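
A minimal usage sketch, assuming a prepared `task: Task` (with `raw_task`, `format_args`, and a `RewardFn` subclass as its `reward_fn`) and a ready `ModelWrapper` instance; how those two objects are constructed is outside this module, so the setup below is an assumption rather than this library's documented API.

# Sketch only: `task` and `model` are assumed to be built elsewhere by the
# surrounding pipeline; only the calls below come from this module.
workflow = SimpleMMWorkflow(task=task, model=model)

experiences = workflow.run()  # one Experience per sampled response
for exp in experiences:
    print(exp.reward, exp.metrics)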