Source code for trinity.common.rewards.accuracy_reward

# -*- coding: utf-8 -*-
"""Accuracy Reward Function Class."""
from typing import Callable, Optional

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify

from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@REWARD_FUNCTIONS.register_module("accuracy_reward")
class AccuracyReward(RewardFn):
    """A reward function that rewards correct answers.

    Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py

    The response and the ground truth are each reduced to a comparable form,
    either by a user-supplied ``answer_parser`` or by ``math_verify.parse``,
    and then compared with ``math_verify.verify``.
    """

    def __init__(self, answer_parser: Optional[Callable[[str], str]] = None):
        """Store the optional custom parser.

        Args:
            answer_parser: Callable applied to both the response and the
                ground truth before verification. When ``None`` (or otherwise
                falsy), ``math_verify.parse`` with LaTeX extraction is used
                instead.
        """
        self.answer_parser = answer_parser

    def __call__(  # type: ignore
        self,
        response: str,
        prompt: Optional[str] = None,
        truth: Optional[str] = None,
    ) -> dict[str, float]:
        """Score ``response`` against ``truth``.

        Args:
            response: The model output to be graded.
            prompt: Accepted for interface compatibility; not used here.
            truth: The ground-truth answer.

        Returns:
            ``{"accuracy": r}`` where ``r`` is ``float(verify(...))``, i.e.
            ``1.0`` when the answer is verified against the ground truth and
            ``0.0`` otherwise. If ``verify`` raises, the failure is logged
            and ``0.0`` is returned.
        """
        if self.answer_parser:
            answer_parsed = self.answer_parser(response)
            truth_parsed = self.answer_parser(truth)  # type: ignore [arg-type]
        else:
            answer_parsed, truth_parsed = self._parse_with_math_verify(response, truth)

        # Reward 1 if the content is the same as the ground truth, 0 otherwise.
        # math_verify.verify is order-sensitive: its signature is
        # verify(gold, target), so the ground truth must be passed first.
        try:
            reward = float(verify(truth_parsed, answer_parsed))
        except Exception as e:
            # Best-effort scoring: a verification failure counts as incorrect.
            logger.info(f"verify failed: {e}, answer: {answer_parsed}, gold: {truth_parsed}")
            reward = 0.0
        return {"accuracy": reward}

    @staticmethod
    def _parse_with_math_verify(response, truth):
        """Parse response and ground truth with math_verify's LaTeX extraction.

        Returns:
            A ``(answer_parsed, truth_parsed)`` tuple. When nothing can be
            extracted from ``truth``, the raw ``truth`` value is used as the
            gold answer instead.
        """
        truth_parsed = parse(
            truth,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if not truth_parsed:
            # Nothing extractable from the ground truth: compare against it verbatim.
            truth_parsed = truth

        answer_parsed = parse(
            response,
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        boxed="all",
                        units=True,
                    ),
                    # Ensures that boxed is tried first
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )
        return answer_parsed, truth_parsed