Source code for trinity.common.workflows.eval_workflow

# -*- coding: utf-8 -*-
"""Evaluation Workflow Class"""

from dataclasses import asdict
from typing import List, Optional

import openai

from trinity.common.config import GenerationConfig
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
from trinity.utils.log import get_logger
from trinity.utils.math_eval_utils import verify_math_answer

logger = get_logger(__name__)


[docs] @WORKFLOWS.register_module("math_eval_workflow") class MathEvalWorkflow(Workflow): """ A workflow for standard math evaluation. The evaluation standard and prompting style are follow the Qwen2.5-Math model's evaluation methodology. For more details on their approach, see: https://github.com/QwenLM/Qwen2.5-Math """
[docs] def __init__( self, *, task: Task, model: ModelWrapper, auxiliary_models: Optional[List[openai.OpenAI]] = None, ): super().__init__( task=task, model=model, auxiliary_models=auxiliary_models, ) self.raw_task = task.raw_task self.truth = task.truth # TODO: customize the config in the yaml self.eval_gen_args = asdict(GenerationConfig(temperature=0.6, top_p=0.8, logprobs=0, n=1))
@property def resettable(self): return False @property def repeatable(self): return False
[docs] def format_messages(self): """Format message for the evaluation of qwen_boxed type.""" if not self.raw_task or "question" not in self.raw_task: raise ValueError("Raw task data must contain a 'question' field for MathEvalWorkflow.") problem_input = self.raw_task["question"] system_prompt = "You are a helpful assistant." user_prompt = f"{problem_input}\nPlease reason step by step, and put your final answer within \\boxed{{}}." messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] return messages
[docs] def run(self) -> List[Experience]: messages = self.format_messages() responses: List[Experience] = self.model.chat(messages, **self.eval_gen_args) for response in responses: if response.response_text is None or self.task.truth is None: continue accuracy, _ = verify_math_answer( response_text=response.response_text, ground_truth=self.task.truth ) acc_metrics = {"accuracy": accuracy} if response.metrics is None: response.metrics = {} response.metrics.update(acc_metrics) return responses