Source code for trinity.common.workflows.eval_workflow
# -*- coding: utf-8 -*-
"""Evaluation Workflow Class"""
from dataclasses import asdict
from typing import List, Optional
from trinity.common.config import GenerationConfig
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards.qwen25_eval import verify_math_answer
from trinity.common.workflows.workflow import Task, Workflow
class MathEvalWorkflow(Workflow):
"""
A workflow for standard math evaluation.
The evaluation standard and prompting style are follow the Qwen2.5-Math
model's evaluation methodology. For more details on their approach, see:
https://github.com/QwenLM/Qwen2.5-Math
"""
    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[ModelWrapper]] = None,
    ):
        super().__init__(
            task=task,
            model=model,
            auxiliary_models=auxiliary_models,
        )
        self.raw_task = task.raw_task
        self.truth = task.truth
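        # Sampling settings below are unpacked as keyword arguments into the
        # chat calls in run()/run_async(); a single response (n=1) is drawn
        # per task during evaluation.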
        # TODO: customize the config in the yaml
        self.eval_gen_args = asdict(GenerationConfig(temperature=0.6, top_p=0.8, logprobs=0, n=1))
    def run(self) -> List[Experience]:
        messages = self.format_messages()
        responses: List[Experience] = self.model.chat(messages, **self.eval_gen_args)
        for response in responses:
            # Skip scoring when there is no generated text or no ground truth.
            if response.response_text is None or self.truth is None:
                continue
            accuracy, _ = verify_math_answer(
                response_text=response.response_text, ground_truth=self.truth
            )
            acc_metrics = {"accuracy": accuracy}
            if response.metrics is None:
                response.metrics = {}
            response.metrics.update(acc_metrics)
        return responses
class AsyncMathEvalWorkflow(MathEvalWorkflow):
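    """Asynchronous variant of MathEvalWorkflow that awaits the model's chat_async API."""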
    is_async: bool = True
    async def run_async(self) -> List[Experience]:
        messages = self.format_messages()
        responses: List[Experience] = await self.model.chat_async(messages, **self.eval_gen_args)
        for response in responses:
            # Same scoring as the synchronous run(), applied to awaited responses.
            if response.response_text is None or self.truth is None:
                continue
            accuracy, _ = verify_math_answer(
                response_text=response.response_text, ground_truth=self.truth
            )
            acc_metrics = {"accuracy": accuracy}
            if response.metrics is None:
                response.metrics = {}
            response.metrics.update(acc_metrics)
        return responses
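
A minimal usage sketch (illustrative, not part of the module above): assuming a Trinity Task carrying the problem and its ground truth and a ModelWrapper instance have already been constructed elsewhere (both are configuration-driven), the evaluation roughly proceeds as follows.

# Hypothetical objects: `task` and `model` must be built by the surrounding
# Trinity setup; their construction is not shown in this module.
workflow = MathEvalWorkflow(task=task, model=model)
experiences = workflow.run()  # one Experience per sampled response (n=1 here)
for exp in experiences:
    if exp.metrics:  # "accuracy" is attached only when a ground truth was available
        print(exp.metrics["accuracy"])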