# -*- coding: utf-8 -*-
"""Evaluation Workflow Class"""
from dataclasses import asdict
from typing import List, Optional
import openai
from trinity.common.config import GenerationConfig
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
from trinity.utils.log import get_logger
from trinity.utils.math_eval_utils import verify_math_answer
logger = get_logger(__name__)
@WORKFLOWS.register_module("math_eval_workflow")
class MathEvalWorkflow(Workflow):
"""
A workflow for standard math evaluation.
The evaluation standard and prompting style are follow the Qwen2.5-Math
model's evaluation methodology. For more details on their approach, see:
https://github.com/QwenLM/Qwen2.5-Math
"""
    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        super().__init__(
            task=task,
            model=model,
            auxiliary_models=auxiliary_models,
        )
        self.raw_task = task.raw_task
        self.truth = task.truth
        # TODO: customize the config in the yaml
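        # Decoding settings for evaluation; `asdict` turns the GenerationConfig
        # dataclass into keyword arguments for `self.model.chat`, and n=1
        # requests a single completion per prompt.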
        self.eval_gen_args = asdict(GenerationConfig(temperature=0.6, top_p=0.8, logprobs=0, n=1))
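    # Evaluation is a one-shot pass over the task: the workflow is neither
    # reset with a new task nor repeated (semantics assumed from the base
    # Workflow interface).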
    @property
    def resettable(self):
        return False

    @property
    def repeatable(self):
        return False
    def run(self) -> List[Experience]:
        messages = self.format_messages()
        responses: List[Experience] = self.model.chat(messages, **self.eval_gen_args)
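        # Score each returned Experience against the ground truth and attach
        # the result as an "accuracy" metric; entries without a response or
        # without a ground truth are skipped.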
        for response in responses:
            if response.response_text is None or self.task.truth is None:
                continue
            accuracy, _ = verify_math_answer(
                response_text=response.response_text, ground_truth=self.task.truth
            )
            acc_metrics = {"accuracy": accuracy}
            if response.metrics is None:
                response.metrics = {}
            response.metrics.update(acc_metrics)
        return responses