"""
This file contains the naive DAPO reward function for math tasks.
Adapted from https://github.com/LLM360/Reasoning360/blob/main/verl/utils/reward_score/naive_dapo.py
"""
import concurrent
import concurrent.futures
import functools
import math
import os
import re
import resource

import sympy
from pylatexenc import latex2text
from sympy.parsing import sympy_parser
# Constants for normalization.
# Ordered (old, new) literal replacements; normalize_final_answer applies them
# with str.replace in exactly this order, so earlier entries affect later ones.
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    # All plain spaces are dropped after the article removals above.
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
# Substrings that normalize_final_answer deletes verbatim (str.replace with ""),
# in this order. Order matters: e.g. "\\text{}^2" must be removed before the
# bare "\\text{}" entry would consume its prefix.
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    # NOTE(review): this and "\\text{\n}" below are non-raw strings, so "\n" is a
    # real newline character, not backslash-n -- confirm that is intended.
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]
[docs]
def normalize_final_answer(final_answer: str) -> str:
    """Normalize a final answer to a quantitative reasoning question.

    Only the text after the last "=" is kept. The SUBSTITUTIONS and
    REMOVED_EXPRESSIONS tables are then applied in order. Next, a fixed
    sequence of regex rewrites runs:
    - extract the first $...$ span;
    - unwrap \\text, \\textbf, \\overline and \\boxed;
    - expand shorthand \\frac / \\sqrt arguments, e.g.
      ``\\fracab -> \\frac{a}{b}``, ``\\fracabc -> \\frac{a}{b}c``,
      ``\\sqrta -> \\sqrt{a}``.

    Remaining dollar signs are dropped. Thousands separators are removed
    from pure digit strings.

    Args:
        final_answer: The answer string to normalize.

    Returns:
        The normalized answer string, stripped of surrounding whitespace.
    """
    final_answer = final_answer.split("=")[-1]

    for old, new in SUBSTITUTIONS:
        final_answer = final_answer.replace(old, new)
    for junk in REMOVED_EXPRESSIONS:
        final_answer = final_answer.replace(junk, "")

    # Applied strictly in this order; later patterns see earlier results.
    regex_rewrites = (
        (r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$"),
        (r"(\\text\{)(.*?)(\})", "\\2"),
        (r"(\\textbf\{)(.*?)(\})", "\\2"),
        (r"(\\overline\{)(.*?)(\})", "\\2"),
        (r"(\\boxed\{)(.*)(\})", "\\2"),
        (r"(frac)([^{])(.)", "frac{\\2}{\\3}"),
        (r"(sqrt)([^{])", "sqrt{\\2}"),
    )
    for pattern, replacement in regex_rewrites:
        final_answer = re.sub(pattern, replacement, final_answer)

    final_answer = final_answer.replace("$", "")

    without_commas = final_answer.replace(",", "")
    if without_commas.isdigit():
        final_answer = without_commas
    return final_answer.strip()
# sympy might hang -- we don't care about trying to be lenient in these cases.
# should_allow_eval refuses sympy evaluation when an expression contains any of
# these substrings or matches any of these regexes (braced/parenthesized
# exponents, chained exponents like "^2^", or multi-digit exponents like "^12").
BAD_SUBSTRINGS = ["^{", "^("]
BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"]
# Characters split_tuple accepts as the opening/closing delimiter of a
# tuple or interval.
TUPLE_CHARS = "()[]"
[docs]
def timeout(timeout_seconds: int = 8):
    """Return a decorator that aborts the wrapped call after ``timeout_seconds``.

    The deadline is enforced with ``SIGALRM``, so this only works on POSIX, and
    the wrapper must run in a thread that is allowed to install signal handlers
    (the main thread in CPython). The previously installed SIGALRM handler is
    restored, and the alarm is cleared, when the call returns or raises.
    ``functools.wraps`` preserves the wrapped function's name, docstring and
    qualified name.

    Args:
        timeout_seconds: Whole seconds before the call is interrupted.

    Raises:
        NotImplementedError: On non-POSIX platforms (raised when ``timeout``
            itself is called, i.e. at decoration time).
        TimeoutError: From the wrapped call when the deadline passes.
    """
    if os.name != "posix":
        raise NotImplementedError(f"Unsupported OS: {os.name}")

    import signal

    def handler(signum, frame):
        raise TimeoutError("Operation timed out!")

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            old_handler = signal.getsignal(signal.SIGALRM)
            signal.signal(signal.SIGALRM, handler)
            signal.alarm(timeout_seconds)
            try:
                return func(*args, **kwargs)
            finally:
                # Always disarm first so the alarm cannot fire after return.
                signal.alarm(0)
                signal.signal(signal.SIGALRM, old_handler)

        return wrapper

    return decorator
def _sympy_parse(expr: str):
    """Parse ``expr`` into a sympy expression.

    ``^`` is rewritten to Python's ``**`` first. Parsing uses sympy's standard
    transformations plus ``implicit_multiplication_application``.
    """
    parser_transforms = sympy_parser.standard_transformations + (
        sympy_parser.implicit_multiplication_application,
    )
    python_expr = expr.replace("^", "**")
    return sympy_parser.parse_expr(python_expr, transformations=parser_transforms)
def _parse_latex(expr: str) -> str:
    """Convert a LaTeX snippet into plain text that sympy has a chance to read.

    ``\\tfrac`` / ``\\dfrac`` are treated as ``\\frac``. A space is inserted
    before every ``\\frac`` so mixed numbers ("3\\frac12") stay separated. The
    text is rendered with pylatexenc, and the Unicode symbols it emits are
    mapped back to ASCII names.
    """
    for frac_variant in ("\\tfrac", "\\dfrac"):
        expr = expr.replace(frac_variant, "\\frac")
    expr = expr.replace("\\frac", " \\frac")  # Play nice with mixed numbers.

    text = latex2text.LatexNodes2Text().latex_to_text(expr)

    unicode_to_ascii = (
        ("√", "sqrt"),
        ("π", "pi"),
        ("∞", "inf"),
        ("∪", "U"),
        ("·", "*"),
        ("×", "*"),
    )
    for symbol, ascii_form in unicode_to_ascii:
        text = text.replace(symbol, ascii_form)
    return text.strip()
def _is_float(num: str) -> bool:
try:
float(num)
return True
except ValueError:
return False
def _is_int(x: float) -> bool:
try:
return abs(x - int(round(x))) <= 1e-7
except Exception:
return False
def _is_frac(expr: str) -> bool:
return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr))
def _str_is_int(x: str) -> bool:
    """Return True if ``x`` parses as a number within 1e-7 of an integer.

    Well-formed thousands separators ("1,000") are removed before parsing.
    Tuple-style commas are kept, so "(1,2)" does not parse. Any error yields False.
    """
    try:
        value = float(_strip_properly_formatted_commas(x))
        return abs(value - int(round(value))) <= 1e-7  # type: ignore
    except Exception:
        return False
def _str_to_int(x: str) -> int:
x = x.replace(",", "")
x = float(x)
return int(x)
def _inject_implicit_mixed_number(step: str):
"""
Automatically make a mixed number evalable
e.g. 7 3/4 => 7+3/4
"""
p1 = re.compile("([0-9]) +([0-9])")
step = p1.sub("\\1+\\2", step) # implicit mults
return step
def _strip_properly_formatted_commas(expr: str):
# We want to be careful because we don't want to strip tuple commas
p1 = re.compile(r"(\d)(,)(\d\d\d)($|\D)")
while True:
next_expr = p1.sub("\\1\\3\\4", expr)
if next_expr == expr:
break
expr = next_expr
return next_expr
def _normalize(expr: str) -> str:
    """Canonicalize an answer string for string and sympy comparison.

    The steps, in order:
    1. Unwrap a whole-string ``\\text{...}``.
    2. Strip currency/percent markers, turn " or " / " and " into commas,
       and expand million/billion/trillion into powers of ten.
    3. Delete unit words and ``^\\circ``, then one pair of enclosing braces.
    4. Collapse integral floats ("3.0" -> "3").
    5. Convert remaining LaTeX to plain text (best effort).
    6. Join mixed numbers, lowercase, and canonicalize integer strings.

    Returns ``None`` when ``expr`` is ``None``.
    """
    if expr is None:
        return None

    wrapped = re.search(r"^\\text\{(?P<text>.+?)\}$", expr)
    if wrapped is not None:
        expr = wrapped.group("text")

    # Ordered literal rewrites.
    literal_rewrites = (
        ("\\%", "%"),
        ("\\$", "$"),
        ("$", ""),
        ("%", ""),
        (" or ", " , "),
        (" and ", " , "),
        ("million", "*10^6"),
        ("billion", "*10^9"),
        ("trillion", "*10^12"),
    )
    for needle, replacement in literal_rewrites:
        expr = expr.replace(needle, replacement)

    # Unit words, optionally pluralized and/or followed by a power ("cm^2").
    # There is no word boundary, so matches inside longer words are removed too.
    unit_words = (
        "degree", "cm", "centimeter", "meter", "mile", "second", "minute",
        "hour", "day", "week", "month", "year", "foot", "feet", "inch",
        "yard", "liter",
    )
    for unit in unit_words:
        expr = re.sub(rf"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
    expr = re.sub("\\^ *\\\\circ", "", expr)

    if expr.startswith("{") and expr.endswith("}"):
        expr = expr[1:-1]

    expr = re.sub(",\\\\! *", "", expr)
    if _is_float(expr) and _is_int(float(expr)):
        expr = str(int(round(float(expr))))

    if "\\" in expr:
        try:
            expr = _parse_latex(expr)
        except Exception:
            pass  # best effort: keep the raw text if LaTeX conversion fails

    # Edge case with mixed numbers and negative signs: "- 3" -> "-3".
    expr = re.sub("- *", "-", expr)
    expr = _inject_implicit_mixed_number(expr)
    # Text answers are compared case-insensitively.
    expr = expr.lower()

    if _str_is_int(expr):
        expr = str(_str_to_int(expr))
    return expr
[docs]
def count_unknown_letters_in_expr(expr: str):
    """Count distinct alphabetic characters in ``expr``.

    The substrings "sqrt" and then "frac" are removed first (in that order),
    so those names do not count as variables.
    """
    stripped = expr.replace("sqrt", "").replace("frac", "")
    return len({ch for ch in stripped if ch.isalpha()})
[docs]
def should_allow_eval(expr: str):
    """Decide whether ``expr`` is safe enough to hand to sympy.

    Rejects expressions with more than two distinct unknown letters, and
    expressions containing any BAD_SUBSTRINGS entry or matching any
    BAD_REGEXES pattern, since those forms can make sympy hang.
    """
    if count_unknown_letters_in_expr(expr) > 2:
        return False
    if any(fragment in expr for fragment in BAD_SUBSTRINGS):
        return False
    return not any(re.search(pattern, expr) is not None for pattern in BAD_REGEXES)
# @timeout(timeout_seconds=10)
[docs]
def _are_equal_under_sympy_worker(
    ground_truth_normalized: str,
    given_normalized: str,
    timeout_seconds: int = 10,
    memory_limit_bytes: int = 1024**3,
) -> bool:
    """Run the sympy equality check; executed inside a worker process.

    This must live at module level: ProcessPoolExecutor pickles the submitted
    callable, and locally defined (nested) functions cannot be pickled. The
    nested timed function below is created and called inside the worker, so it
    never needs pickling. The address-space limit and SIGALRM timeout apply
    only to the worker process.
    """

    @timeout(timeout_seconds=timeout_seconds)
    def check_equal() -> bool:
        resource.setrlimit(resource.RLIMIT_AS, (memory_limit_bytes, memory_limit_bytes))
        expr = f"({ground_truth_normalized})-({given_normalized})"
        if should_allow_eval(expr):
            sympy_diff = _sympy_parse(expr)
            simplified = sympy.simplify(sympy_diff)
            if simplified == 0:
                return True
        return False

    return check_equal()


def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
    """Return True if sympy simplifies ``ground_truth - given`` to 0.

    The check runs in a separate process with a 1 GiB address-space cap and a
    10 s deadline, so a runaway sympy computation cannot take down the caller.
    Any failure (parse error, timeout, memory error, worker crash) is treated
    as "not equal".

    Previously a nested function was submitted here. It could not be pickled,
    so every call failed and this function always returned False.
    """
    with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
        future = executor.submit(
            _are_equal_under_sympy_worker, ground_truth_normalized, given_normalized
        )
        try:
            return future.result(timeout=10)
        except Exception:  # includes concurrent.futures.TimeoutError
            future.cancel()
            return False
[docs]
def split_tuple(expr: str):
    """
    Split the elements in a tuple/interval, while handling well-formatted commas in large numbers

    Thousands separators are stripped first. If the result is longer than two
    characters, starts and ends with a TUPLE_CHARS delimiter, and has no other
    delimiter inside, the interior is split on commas (each piece stripped).
    Otherwise the whole expression is returned as a single element. An empty
    expression yields an empty list.
    """
    expr = _strip_properly_formatted_commas(expr)
    if not expr:
        return []

    inner = expr[1:-1]
    is_bracketed = (
        len(expr) > 2
        and expr[0] in TUPLE_CHARS
        and expr[-1] in TUPLE_CHARS
        and not any(ch in inner for ch in TUPLE_CHARS)
    )
    if not is_bracketed:
        return [expr]
    return [piece.strip() for piece in inner.split(",")]
[docs]
def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]:
    """Grade a candidate answer against the ground truth.

    The answer will be considered correct if:
      (a) it normalizes to the same string as the ground truth, using either
          the PRIME ``math_normalize.normalize_answer`` or ``_normalize``, OR
      (b) tuple/interval delimiters and element counts match, and every
          element pair is equal under sympy. Exceptions: two fractions must
          match as strings exactly (unreduced fractions are not accepted),
          and a mismatch in integer-ness counts as incorrect without
          consulting sympy.

    Args:
        given_answer: The extracted model answer; may be ``None``.
        ground_truth: The reference answer.

    Returns:
        ``(is_correct, normalized_given_answer)``. When ``given_answer`` is
        ``None`` this is ``(False, "")``, keeping the tuple shape so callers
        can always unpack it.
    """
    if given_answer is None:
        return False, ""

    from verl.utils.reward_score.prime_math import math_normalize

    ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth)
    given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer)
    # Be at least as lenient as mathd.
    if ground_truth_normalized_mathd == given_answer_normalized_mathd:
        return True, given_answer_normalized_mathd

    ground_truth_normalized = _normalize(ground_truth)
    given_normalized = _normalize(given_answer)
    if ground_truth_normalized is None:
        return False, given_normalized
    if ground_truth_normalized == given_normalized:
        return True, given_normalized
    if not given_normalized:
        return False, given_normalized

    ground_truth_elems = split_tuple(ground_truth_normalized)
    given_elems = split_tuple(given_normalized)

    # For tuples/intervals the delimiters must agree, e.g. "(1,2)" != "[1,2]".
    if len(ground_truth_elems) > 1 and (
        ground_truth_normalized[0] != given_normalized[0]
        or ground_truth_normalized[-1] != given_normalized[-1]
    ):
        return False, given_normalized
    if len(ground_truth_elems) != len(given_elems):
        return False, given_normalized

    # Defensive default so the result is always bound, even if no pairs exist.
    is_correct = False
    for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems):
        if _is_frac(ground_truth_elem) and _is_frac(given_elem):
            # If fractions aren't reduced, they shouldn't be marked as correct,
            # so sympy.simplify is not allowed in this case.
            is_correct = ground_truth_elem == given_elem
        elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):
            # One side is an integer and the other is not: no sympy leniency.
            is_correct = False
        else:
            is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
        if not is_correct:
            break
    return is_correct, given_normalized
def _last_boxed_only_string(string):
"""Strictly extract content from \boxed{}."""
idx = string.rfind("\\boxed")
if idx < 0:
idx = string.rfind("\\fbox")
if idx < 0:
return None
i = idx
left_brace_idx = None
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == "{":
num_left_braces_open += 1
if left_brace_idx is None:
left_brace_idx = i
elif string[i] == "}":
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if left_brace_idx is None or right_brace_idx is None:
return None
return string[left_brace_idx + 1 : right_brace_idx].strip()
[docs]
def match_answer(response):
    """Pull the final answer out of a model response.

    Only the text after the last ``</think>`` tag is considered. If that text
    contains a non-empty boxed answer (per ``_last_boxed_only_string``), its
    contents are returned. Otherwise the post-thinking text itself is returned.

    Returns:
        ``(is_matched, answer)``. ``is_matched`` is True only when a non-empty
        boxed answer was found.
    """
    tail = response.split("</think>")[-1]
    boxed = _last_boxed_only_string(tail)
    if boxed:
        return True, boxed
    return False, tail
[docs]
def compute_score(solution_str: str, ground_truth: str) -> float:
    """Compute the reward score for a solution.

    This draws heavily from the LLM-as-judge and PRIME reward functions. The
    answer is extracted with ``match_answer`` and graded with ``grade_answer``.
    If that fails, the PRIME ``math_equal`` grader is tried as a fallback.
    When either side mentions ``\\pi``, the fallback accepts a match with pi
    evaluated as ``math.pi`` or as 3.14.

    Args:
        solution_str: The full model response.
        ground_truth: The ground truth answer.

    Returns:
        Reward score (1.0 for correct, 0.0 for incorrect).
    """
    from verl.utils.reward_score.prime_math.grader import math_equal

    # First assert intended generation and gt type.
    model_output = str(solution_str)
    ground_truth = str(ground_truth)

    # Extract answer from generated output.
    _, extracted_model_output = match_answer(model_output)
    # TWK NOTE: WE REMOVED THE RESPONSE TRUNCATION FROM math_dapo.compute_score

    # Verify the solution, first check simple comparisons.
    correct, _ = grade_answer(extracted_model_output, ground_truth)
    if not correct:
        try:
            if "\\pi" in extracted_model_output or "\\pi" in ground_truth:
                # The keyword was previously misspelled "tiemout". The resulting
                # TypeError was swallowed, so this branch never succeeded.
                correct = any(
                    math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi)
                    for pi in (math.pi, 3.14)
                )
            else:
                correct = math_equal(extracted_model_output, ground_truth, timeout=True)
        except Exception:
            correct = False
    return 1.0 if correct else 0.0