Source code for trinity.common.rewards.math_reward

# -*- coding: utf-8 -*-
"""Math Reward Function Class."""
from typing import Optional

from trinity.common.rewards.accuracy_reward import AccuracyReward
from trinity.common.rewards.format_reward import FormatReward
from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.eval_utils import (
    compute_score,
    simple_answer_parser,
    validate_think_pattern,
)
from trinity.utils.log import get_logger

logger = get_logger(__name__)


[docs] @REWARD_FUNCTIONS.register_module("math_reward") class MathRewardFn(RewardFn): """A reward function that rewards for math task.""" DEFAULT_FORMAT_PATTERN = r".*?<think>.*?</think>\s*<answer>.*?</answer>\s*$" DEFAULT_ANSWER_PARSER = simple_answer_parser
[docs] def __init__( self, answer_parser=DEFAULT_ANSWER_PARSER, pattern=DEFAULT_FORMAT_PATTERN, ) -> None: self.accuracy_reward = AccuracyReward(answer_parser) self.format_reward = FormatReward(pattern)
def __call__( # type: ignore self, response: str, prompt: Optional[str] = None, truth: Optional[str] = None, ) -> dict[str, float]: accuracy_score = self.accuracy_reward(response, prompt, truth) format_score = self.format_reward(response) return {**accuracy_score, **format_score}
[docs] @REWARD_FUNCTIONS.register_module("math_boxed_reward") class MathBoxedRewardFn(RewardFn): """A reward function that rewards for math task."""
[docs] def __init__( self, **kwargs, ) -> None: pass
def __call__( # type: ignore self, response: str, truth: Optional[str] = None, with_think: Optional[bool] = False, format_score_coef: Optional[float] = 0.1, **kwargs, ) -> dict[str, float]: accuracy_score = compute_score(response, truth) format_score = 0.0 if with_think and not validate_think_pattern(response): format_score = (format_score_coef or 0.1) * -1.0 return {"accuracy": accuracy_score, "format_score": format_score}