Source code for trinity.common.rewards.format_reward

"""Format reward function."""

import re
from typing import Optional

from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@REWARD_FUNCTIONS.register_module("format_reward")
class FormatReward(RewardFn):
    """Score a response on whether it follows the think/answer tag layout.

    The default pattern expects the reasoning to sit between ``<think>`` and
    ``</think>`` and the final answer between ``<answer>`` and ``</answer>``,
    with each tag on its own line.

    Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
    """

    def __init__(self, pattern: Optional[str] = None):
        """Store the regular expression used to validate responses.

        Args:
            pattern: Regular expression a response must match. A falsy value
                (``None`` or an empty string) selects the default
                think/answer layout pattern.
        """
        default_pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
        self.pattern = pattern or default_pattern

    def __call__(  # type: ignore
        self,
        response,
    ) -> dict[str, float]:
        """Return the format score for ``response``.

        Args:
            response: Text to check against ``self.pattern``.

        Returns:
            ``{"format_score": 0.1}`` when the pattern matches starting at
            the beginning of ``response``, otherwise
            ``{"format_score": -0.1}``.
        """
        # DOTALL lets ``.`` span the newlines inside the think/answer bodies.
        # NOTE(review): MULTILINE lets ``$`` match at the end of any line, so
        # extra lines after ``</answer>`` still count as a match -- confirm
        # this leniency is intended.
        flags = re.DOTALL | re.MULTILINE
        if re.match(self.pattern, response, flags) is None:
            return {"format_score": -0.1}
        return {"format_score": 0.1}