Source code for trinity.common.rewards.naive_dapo_score

"""
This file contains the naive dapo reward function for math tasks.
Adapted from https://github.com/LLM360/Reasoning360/blob/main/verl/utils/reward_score/naive_dapo.py
"""

import concurrent
import math
import os
import re
import resource

import sympy
from pylatexenc import latex2text
from sympy.parsing import sympy_parser

# Constants for normalization
# Ordered (before, after) pairs. normalize_final_answer applies each pair with
# str.replace, in list order, before any REMOVED_EXPRESSIONS are stripped.
# Order matters: e.g. the article/space removals run before the "\\text{and}"
# rules, which therefore match the already space-free text.
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]

# Substrings that normalize_final_answer deletes outright (str.replace with ""),
# in list order, after SUBSTITUTIONS has been applied. Because whitespace has
# already been removed at that point, multi-word entries appear run together
# (e.g. "childrentickets").
# NOTE(review): "\\text{\ns}" and "\\text{\n}" contain a real newline character
# ("\n" is not escaped) -- confirm this is intended.
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]


def normalize_final_answer(final_answer: str) -> str:
    """Canonicalize a final answer to a quantitative reasoning question.

    The pipeline is: keep only the text after the last ``=``; apply the
    ``SUBSTITUTIONS`` pairs and strip every ``REMOVED_EXPRESSIONS`` entry;
    run a fixed sequence of regex rewrites (math-span extraction, wrapper
    removal, shorthand TeX expansion); drop remaining ``$`` signs; and remove
    thousands separators when the result is otherwise all digits.

    Args:
        final_answer: The answer string to normalize.

    Returns:
        The normalized answer string, stripped of surrounding whitespace.
    """
    answer = final_answer.rpartition("=")[2]

    # Literal substitutions first, then outright removals.
    for old, new in SUBSTITUTIONS:
        answer = answer.replace(old, new)
    for unwanted in REMOVED_EXPRESSIONS:
        answer = answer.replace(unwanted, "")

    # Regex rewrites, applied strictly in this order.
    rewrites = (
        # Reduce text surrounding a $...$ span to the span itself.
        (r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$"),
        # Unwrap \text{..}, \textbf{..}, \overline{..} and \boxed{..}.
        (r"(\\text\{)(.*?)(\})", "\\2"),
        (r"(\\textbf\{)(.*?)(\})", "\\2"),
        (r"(\\overline\{)(.*?)(\})", "\\2"),
        (r"(\\boxed\{)(.*)(\})", "\\2"),
        # Shorthand TeX:
        #   \fracab -> \frac{a}{b}      \fracabc -> \frac{a}{b}c
        #   \sqrta  -> \sqrt{a}         \sqrtab  -> \sqrt{a}b
        # (\frac{abc}{bef} is already braced and is left alone.)
        (r"(frac)([^{])(.)", "frac{\\2}{\\3}"),
        (r"(sqrt)([^{])", "sqrt{\\2}"),
    )
    for pattern, replacement in rewrites:
        answer = re.sub(pattern, replacement, answer)

    answer = answer.replace("$", "")

    # Plain integers written with thousands separators lose their commas.
    without_commas = answer.replace(",", "")
    if without_commas.isdigit():
        answer = without_commas

    return answer.strip()
# sympy might hang -- we don't care about trying to be lenient in these cases BAD_SUBSTRINGS = ["^{", "^("] BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"] TUPLE_CHARS = "()[]"
def timeout(timeout_seconds: int = 8):
    """Build a decorator that aborts the wrapped call after a time limit.

    The limit is enforced with ``SIGALRM``, so it is only available on POSIX
    systems and (per the Python ``signal`` documentation) only when the
    wrapped function is called from the main thread of its process.

    Args:
        timeout_seconds: Whole seconds to allow before the call is aborted.

    Returns:
        A decorator. Calling the decorated function raises ``TimeoutError``
        if it runs longer than ``timeout_seconds``.

    Raises:
        NotImplementedError: Immediately, when ``os.name`` is not ``"posix"``.
    """
    if os.name != "posix":
        raise NotImplementedError(f"Unsupported OS: {os.name}")

    import functools
    import signal

    def decorator(func):
        def handler(signum, frame):
            raise TimeoutError("Operation timed out!")

        # Preserve the wrapped function's name/docstring for debugging.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            old_handler = signal.getsignal(signal.SIGALRM)
            signal.signal(signal.SIGALRM, handler)
            signal.alarm(timeout_seconds)
            try:
                return func(*args, **kwargs)
            finally:
                # Always disarm the alarm and put the previous handler back.
                signal.alarm(0)
                # getsignal() returns None for handlers not installed from
                # Python; signal.signal() rejects None, so fall back to the
                # default disposition in that case.
                signal.signal(
                    signal.SIGALRM,
                    old_handler if old_handler is not None else signal.SIG_DFL,
                )

        return wrapper

    return decorator
def _sympy_parse(expr: str): """Parses an expression with sympy.""" py_expr = expr.replace("^", "**") return sympy_parser.parse_expr( py_expr, transformations=( sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,) ), ) def _parse_latex(expr: str) -> str: """Attempts to parse latex to an expression sympy can read.""" expr = expr.replace("\\tfrac", "\\frac") expr = expr.replace("\\dfrac", "\\frac") expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. expr = latex2text.LatexNodes2Text().latex_to_text(expr) # Replace the specific characters that this parser uses. expr = expr.replace("√", "sqrt") expr = expr.replace("π", "pi") expr = expr.replace("∞", "inf") expr = expr.replace("∪", "U") expr = expr.replace("·", "*") expr = expr.replace("×", "*") return expr.strip() def _is_float(num: str) -> bool: try: float(num) return True except ValueError: return False def _is_int(x: float) -> bool: try: return abs(x - int(round(x))) <= 1e-7 except Exception: return False def _is_frac(expr: str) -> bool: return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) def _str_is_int(x: str) -> bool: try: x = _strip_properly_formatted_commas(x) x = float(x) return abs(x - int(round(x))) <= 1e-7 # type: ignore except Exception: return False def _str_to_int(x: str) -> int: x = x.replace(",", "") x = float(x) return int(x) def _inject_implicit_mixed_number(step: str): """ Automatically make a mixed number evalable e.g. 7 3/4 => 7+3/4 """ p1 = re.compile("([0-9]) +([0-9])") step = p1.sub("\\1+\\2", step) # implicit mults return step def _strip_properly_formatted_commas(expr: str): # We want to be careful because we don't want to strip tuple commas p1 = re.compile(r"(\d)(,)(\d\d\d)($|\D)") while True: next_expr = p1.sub("\\1\\3\\4", expr) if next_expr == expr: break expr = next_expr return next_expr def _normalize(expr: str) -> str: """Normalize answer expressions.""" if expr is None: return None # Remove enclosing `\text{}`. 
m = re.search(r"^\\text\{(?P<text>.+?)\}$", expr) if m is not None: expr = m.group("text") expr = expr.replace("\\%", "%") expr = expr.replace("\\$", "$") expr = expr.replace("$", "") expr = expr.replace("%", "") expr = expr.replace(" or ", " , ") expr = expr.replace(" and ", " , ") expr = expr.replace("million", "*10^6") expr = expr.replace("billion", "*10^9") expr = expr.replace("trillion", "*10^12") for unit in [ "degree", "cm", "centimeter", "meter", "mile", "second", "minute", "hour", "day", "week", "month", "year", "foot", "feet", "inch", "yard", "liter", ]: expr = re.sub(rf"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) expr = re.sub("\\^ *\\\\circ", "", expr) if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": expr = expr[1:-1] expr = re.sub(",\\\\! *", "", expr) if _is_float(expr) and _is_int(float(expr)): expr = str(int(round(float(expr)))) if "\\" in expr: try: expr = _parse_latex(expr) except Exception: pass # edge case with mixed numbers and negative signs expr = re.sub("- *", "-", expr) expr = _inject_implicit_mixed_number(expr) # don't be case sensitive for text answers expr = expr.lower() if _str_is_int(expr): expr = str(_str_to_int(expr)) return expr
def count_unknown_letters_in_expr(expr: str):
    """Count the distinct alphabetic characters left in *expr* once every
    occurrence of ``sqrt`` and ``frac`` has been removed."""
    remainder = expr.replace("sqrt", "").replace("frac", "")
    distinct_letters = {ch for ch in remainder if ch.isalpha()}
    return len(distinct_letters)
def should_allow_eval(expr: str):
    """Decide whether *expr* may be handed to sympy.

    Returns False when the expression has more than two distinct unknown
    letters, contains any ``BAD_SUBSTRINGS`` entry, or matches any
    ``BAD_REGEXES`` pattern; True otherwise.
    """
    # we don't want to try parsing unknown text or functions of more than two variables
    if count_unknown_letters_in_expr(expr) > 2:
        return False

    if any(fragment in expr for fragment in BAD_SUBSTRINGS):
        return False

    return not any(re.search(pattern, expr) for pattern in BAD_REGEXES)
def _check_equal_under_sympy_worker(ground_truth_normalized: str, given_normalized: str) -> bool:
    """Child-process body for ``are_equal_under_sympy``.

    This must be a module-level function: ``ProcessPoolExecutor`` pickles the
    callable it is given, and functions defined inside another function cannot
    be pickled.

    Returns:
        True when sympy simplifies ``(ground_truth) - (given)`` to 0, else False.

    Raises:
        TimeoutError: If parsing/simplifying takes longer than 10 seconds.
    """

    # The SIGALRM-based limit is applied here, inside the worker process, so
    # that a hung simplification is interrupted where it is running.
    @timeout(timeout_seconds=10)
    def check_equal():
        # Cap the worker's address space at 1 GiB so a runaway simplification
        # fails instead of exhausting memory.
        memory_size = 1024**3
        resource.setrlimit(resource.RLIMIT_AS, (memory_size, memory_size))

        expr = f"({ground_truth_normalized})-({given_normalized})"
        if should_allow_eval(expr):
            sympy_diff = _sympy_parse(expr)
            simplified = sympy.simplify(sympy_diff)
            if simplified == 0:
                return True
        return False

    return check_equal()


def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
    """Check whether two normalized expressions are equal according to sympy.

    The comparison runs in a single-worker child process with a time limit
    and a memory cap. Any failure -- timeout, parse error, memory error,
    worker crash -- is treated as "not equal".

    Args:
        ground_truth_normalized: Normalized reference expression.
        given_normalized: Normalized candidate expression.

    Returns:
        True when the difference of the two expressions simplifies to 0,
        False otherwise (including on any error).
    """
    # A bare `import concurrent` does not load the `futures` submodule;
    # import it explicitly so the attribute lookup cannot fail.
    import concurrent.futures

    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
        future = executor.submit(
            _check_equal_under_sympy_worker,
            ground_truth_normalized,
            given_normalized,
        )
        try:
            return future.result(timeout=10)
        except Exception:
            # Covers concurrent.futures.TimeoutError as well as anything the
            # worker raised.
            future.cancel()
            return False
def split_tuple(expr: str):
    """
    Break a tuple/interval into its elements, treating well-formatted
    thousands commas in large numbers as part of the number.

    Returns an empty list for an empty expression, the comma-separated and
    stripped elements when *expr* is wrapped in ``TUPLE_CHARS`` with none of
    those characters inside, and ``[expr]`` otherwise.
    """
    expr = _strip_properly_formatted_commas(expr)
    if not expr:
        return []

    inner = expr[1:-1]
    is_flat_tuple = (
        len(expr) > 2
        and expr[0] in TUPLE_CHARS
        and expr[-1] in TUPLE_CHARS
        and not any(ch in inner for ch in TUPLE_CHARS)
    )
    if not is_flat_tuple:
        return [expr]

    return [piece.strip() for piece in inner.split(",")]
def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]:
    """Grade a candidate answer against the ground truth.

    The answer will be considered correct if:
    (a) it normalizes to the same string as the ground truth answer
    OR
    (b) sympy can simplify the difference between the expressions to 0

    Args:
        given_answer: The candidate answer; ``None`` is graded as incorrect.
        ground_truth: The reference answer.

    Returns:
        ``(is_correct, normalized_given_answer)``. For a ``None`` answer the
        second element is the empty string.
    """
    if given_answer is None:
        # Always return a 2-tuple so callers can unpack the result.
        return False, ""

    from verl.utils.reward_score.prime_math import math_normalize

    ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth)
    given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer)

    # be at least as lenient as mathd
    if ground_truth_normalized_mathd == given_answer_normalized_mathd:
        return True, given_answer_normalized_mathd

    ground_truth_normalized = _normalize(ground_truth)
    given_normalized = _normalize(given_answer)

    if ground_truth_normalized is None:
        return False, given_normalized

    if ground_truth_normalized == given_normalized:
        return True, given_normalized

    if len(given_normalized) == 0:
        return False, given_normalized

    ground_truth_elems = split_tuple(ground_truth_normalized)
    given_elems = split_tuple(given_normalized)

    if len(ground_truth_elems) > 1 and (
        ground_truth_normalized[0] != given_normalized[0]
        or ground_truth_normalized[-1] != given_normalized[-1]
    ):
        # Tuple/interval delimiters must agree, e.g. "(1,2)" vs "[1,2]".
        is_correct = False
    elif len(ground_truth_elems) != len(given_elems):
        is_correct = False
    else:
        for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems):
            if _is_frac(ground_truth_elem) and _is_frac(given_elem):
                # if fractions aren't reduced, then shouldn't be marked as correct
                # so, we don't want to allow sympy.simplify in this case
                is_correct = ground_truth_elem == given_elem
            elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):
                # if the ground truth answer is an integer, we require the given
                # answer to be a strict match (no sympy.simplify)
                is_correct = False
            else:
                is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
            if not is_correct:
                break

    return is_correct, given_normalized
def _last_boxed_only_string(string): """Strictly extract content from \boxed{}.""" idx = string.rfind("\\boxed") if idx < 0: idx = string.rfind("\\fbox") if idx < 0: return None i = idx left_brace_idx = None right_brace_idx = None num_left_braces_open = 0 while i < len(string): if string[i] == "{": num_left_braces_open += 1 if left_brace_idx is None: left_brace_idx = i elif string[i] == "}": num_left_braces_open -= 1 if num_left_braces_open == 0: right_brace_idx = i break i += 1 if left_brace_idx is None or right_brace_idx is None: return None return string[left_brace_idx + 1 : right_brace_idx].strip()
def match_answer(response):
    """Pull the final boxed answer out of a model response.

    Only the text after the last ``</think>`` tag is considered.

    Returns:
        ``(True, boxed_content)`` when a non-empty boxed answer is found,
        otherwise ``(False, text_after_last_think_tag)``.
    """
    tail = response.rpartition("</think>")[2]

    # Find boxed
    boxed = _last_boxed_only_string(tail)
    if boxed:
        return True, boxed
    return False, tail
def compute_score(solution_str: str, ground_truth: str) -> float:
    """Compute the reward score for a solution.

    This draws heavily from the LLM-as-judge and PRIME reward functions.
    The boxed answer is extracted from the solution and graded with
    ``grade_answer``; if that fails, ``math_equal`` is tried as a fallback.
    When either side contains ``\\pi``, the fallback is evaluated with pi
    taken as both ``math.pi`` and ``3.14``.

    Args:
        solution_str: The solution string
        ground_truth: The ground truth answer

    Returns:
        Reward score (1.0 for correct, 0.0 for incorrect)
    """
    from verl.utils.reward_score.prime_math.grader import math_equal

    # First assert intended generation and gt type
    model_output = str(solution_str)
    ground_truth = str(ground_truth)

    # Extract answer from generated output
    _, extracted_model_output = match_answer(model_output)

    # TWK NOTE: WE REMOVED THE RESPONSE TRUNCATION FROM math_dapo.compute_score

    # Verify the solution, first check simple comparisons.
    correct, _ = grade_answer(extracted_model_output, ground_truth)

    if not correct:
        try:
            if "\\pi" in extracted_model_output or "\\pi" in ground_truth:
                # Accept the answer if it matches under either value of pi;
                # stops at the first value that matches.
                correct = any(
                    math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi)
                    for pi in (math.pi, 3.14)
                )
            else:
                correct = math_equal(extracted_model_output, ground_truth, timeout=True)
        except Exception:
            # Best effort: a failing fallback check counts as incorrect.
            correct = False

    return 1.0 if correct else 0.0