Source code for trinity.common.workflows.math_rm_workflow

# -*- coding: utf-8 -*-
"""We include the math workflow with rm-gallery reward in this file."""

from typing import List, Optional

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@WORKFLOWS.register_module("math_rm_workflow")
class MathRMWorkflow(SimpleWorkflow):
    """Math workflow registered under the name ``math_rm_workflow``.

    Each rollout returned by the model is scored with ``self.reward_fn``,
    which is expected to return a mapping of named reward components.
    The components are merged into the response's metrics and their sum
    becomes the response's scalar reward.
    """

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        """Reset the workflow with ``task``, then run the base initializer.

        Args:
            task: The task this workflow instance will run.
            model: Model wrapper used to generate rollouts.
            auxiliary_models: Optional extra OpenAI clients, forwarded
                unchanged to the base class.
        """
        # NOTE(review): ``reset`` is deliberately invoked before the base
        # initializer, mirroring the original statement order.
        self.reset(task)
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)

    def run(self) -> List[Experience]:
        """Generate rollouts for the task and attach rewards to each one.

        Returns:
            The responses produced by ``self.model.chat``, each updated in
            place with merged metrics, a scalar ``reward`` and a run id.
        """
        messages = self.format_messages()

        logger.debug("start chat")
        responses = self.model.chat(messages, **self.rollout_args)

        for position, response in enumerate(responses):
            # Score this rollout; the result maps component name -> value.
            reward_breakdown = self.reward_fn(  # type: ignore
                response,
                messages,
                ground_truth=self.truth,
            )

            # Keep every individual component visible in the metrics.
            if response.metrics is None:
                response.metrics = {}
            response.metrics.update(reward_breakdown)

            # The scalar reward is the plain sum of all components.
            reward = sum(reward_breakdown.values())
            response.reward = reward

            # Run ids are offset by ``run_id_base``.
            response.eid.run = position + self.run_id_base

            logger.debug(
                f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
            )

        return responses