# -*- coding: utf-8 -*-
import regex as re
# Final-answer marker of the form "#### <number>"; the captured number may carry
# a leading minus sign, decimal points and comma separators. Not used in this
# chunk of the module — presumably consumed elsewhere (TODO confirm).
ANS_RE = re.compile(
    r"#### (\-?[0-9\.\,]+)"
)
# Sentinel string, presumably used when no answer can be extracted (not used in
# this chunk — verify against callers).
INVALID_ANS = "[invalid]"
def simple_answer_parser(response: str) -> list:
    """Parse the final answer of ``response`` with ``math_verify.parse``.

    If the response contains an ``<answer>...</answer>`` span, only the text of
    the first such span is parsed; otherwise the whole response is parsed.

    Args:
        response: raw model output.
    Returns:
        The value of ``math_verify.parse`` for the selected text (a list of
        parsed candidates per the math_verify documentation, not a ``str``).
    """
    from math_verify import parse

    # DOTALL lets the tagged answer span several lines; without it a multi-line
    # answer is silently ignored and the entire response is parsed instead.
    tagged = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
    if tagged:
        response = tagged.group(1)
    return parse(response)
def find_boxed_answer(raw_answer, timeout=10):
    """
    Return the content of the last LaTeX ``\\boxed`` expression in a model answer.

    Args:
        raw_answer (`str`): raw answer from model
        timeout (`int`): timeout in seconds forwarded to the regex search; per the
            ``regex`` docs, exceeding it raises ``TimeoutError`` (not caught here)
    Returns:
        `str`: for ``\\boxed{...}`` the text inside the outermost braces; if no
        balanced brace group follows ``\\boxed``, the next single character
        (after optional whitespace); ``None`` when no ``\\boxed`` occurs.
    """
    # Group 2 matches a balanced {...} group by recursing into itself via (?2);
    # group 3 is the one-character fallback. Pattern recursion and the timeout
    # keyword come from the third-party ``regex`` module imported as ``re``.
    pattern = r"\\boxed\s*(({(?:\\.|[^{}]|(?2))*})|(.))"
    hits = re.findall(pattern, raw_answer, timeout=timeout)
    if not hits:
        return None
    # findall yields (group1, group2, group3) tuples; the last hit is the answer.
    answer = hits[-1][0]
    return answer[1:-1] if answer.startswith("{") else answer
# copy from Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py
def validate_equation(equation_str, available_numbers):
    """Return True iff the integers in ``equation_str`` are exactly ``available_numbers``.

    Every maximal run of digits in the equation is read as one integer (so
    ``"2.5"`` counts as 2 and 5). The multiset of those integers must equal the
    multiset ``available_numbers``: each available number is used exactly once
    and no other number appears. Any error (e.g. a non-string equation or
    numbers that cannot be sorted together) is reported as False.
    """
    try:
        used = sorted(map(int, re.findall(r"\d+", equation_str)))
        return used == sorted(available_numbers)
    except Exception:
        return False
# copy from Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py
def evaluate_equation(equation_str):
    """Evaluate a plain arithmetic expression, returning None if it is rejected or fails.

    The string may contain only digits, ``+``, ``-``, ``*``, ``/``, ``(``, ``)``,
    ``.`` and whitespace, must not contain ``**``, and is evaluated with ``eval``
    with builtins disabled.

    Returns:
        The numeric result (``int`` or ``float``), or ``None`` when the string is
        rejected, fails to evaluate (syntax error, division by zero, ...), or
        evaluates to something that is not a number.
    """
    try:
        # Only numbers, operators, parentheses, and whitespace are allowed.
        allowed_pattern = r"^[\d+\-*/().\s]+$"
        if not re.match(allowed_pattern, equation_str):
            raise ValueError("Invalid characters in equation.")
        # The whitelist admits "**", which lets a short, possibly untrusted input
        # such as "9**9**9**9" make eval() compute an astronomically large power.
        if "**" in equation_str:
            raise ValueError("Exponentiation is not allowed.")
        # Builtins are disabled so no names can be resolved during evaluation.
        result = eval(equation_str, {"__builtins__": None}, {})
        # e.g. "()" passes the whitelist but evaluates to a tuple, not a number.
        if not isinstance(result, (int, float)):
            raise ValueError("Equation did not evaluate to a number.")
        return result
    except Exception:
        # Deliberate best-effort: any rejection or evaluation error yields None.
        return None
def validate_think_pattern(text):
    """Return True iff ``text`` has exactly one ``<think>`` and one ``</think>``, opening tag first."""
    open_tag, close_tag = "<think>", "</think>"
    if text.count(open_tag) != 1 or text.count(close_tag) != 1:
        return False
    # Both tags are known to occur exactly once, so index() cannot fail here.
    return text.index(open_tag) < text.index(close_tag)
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def compute_score(solution_str, ground_truth) -> float:
    """Score a solution 1.0 if its last boxed answer matches ``ground_truth``, else 0.0.

    The last ``\\boxed``/``\\fbox`` expression is located with
    ``last_boxed_only_string``, unwrapped with ``remove_boxed`` and compared with
    ``is_equiv``. Solutions without a boxed answer score 0.0. Any exception raised
    during extraction or comparison is printed and also scores 0.0.
    """
    try:
        boxed = last_boxed_only_string(solution_str)
        if boxed is None:
            return 0.0
        return 1.0 if is_equiv(remove_boxed(boxed), ground_truth) else 0.0
    except Exception as e:
        print(e)
        return 0.0
# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
def is_equiv(str1, str2, verbose=False):
    """Compare two answer strings after ``strip_string`` normalization.

    Two ``None`` values count as equal (a warning is printed); exactly one
    ``None`` never does. If normalization raises (e.g. an ``assert`` inside a
    helper), the raw values are compared with ``==`` instead. With
    ``verbose=True`` the two normalized strings are printed.
    """
    if str1 is None or str2 is None:
        if str1 is str2:  # both are None
            print("WARNING: Both None")
            return True
        return False
    try:
        normalized1 = strip_string(str1)
        normalized2 = strip_string(str2)
        if verbose:
            print(normalized1, normalized2)
        return normalized1 == normalized2
    except Exception:
        return str1 == str2
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def remove_boxed(s):
    """Strip the boxing command from a boxed expression and return its content.

    Accepted forms:
      * ``"\\boxed <ans>"`` -> ``"<ans>"``
      * ``"\\boxed{<ans>}"`` -> ``"<ans>"``
      * ``"\\fbox{<ans>}"`` -> ``"<ans>"`` (``last_boxed_only_string`` falls back to
        ``\\fbox`` when there is no ``\\boxed``)

    Raises:
        AssertionError: if ``s`` is in none of these forms. It is raised
        explicitly rather than via ``assert`` so the check also holds under
        ``python -O``, while keeping the exception type callers may rely on.
    """
    space_form = "\\boxed "
    # Check the prefix (not mere containment) so "\\boxed{a \\boxed b}" takes the brace path.
    if s.startswith(space_form):
        return s[len(space_form):]
    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left) and s.endswith("}"):
            return s[len(left):-1]
    raise AssertionError("not a boxed expression: {!r}".format(s))
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def last_boxed_only_string(string):
    """Return the last ``\\boxed`` (or, if absent, ``\\fbox``) expression in ``string``.

    Brace form: the slice runs from the command through the ``}`` that balances
    the first ``{`` after it, e.g. ``"x \\boxed{\\frac{1}{2}} y"`` ->
    ``"\\boxed{\\frac{1}{2}}"``.
    Space form (the last ``\\boxed`` is followed by a space): the text after it up
    to the next ``$`` (or the end) is kept, e.g. ``"$\\boxed 5$"`` -> ``"\\boxed 5"``.

    Returns None when neither command occurs or its braces never balance.
    """
    idx = string.rfind("\\boxed")
    # Decide the form from the LAST \boxed only; an earlier "\boxed " elsewhere in
    # the text must not hijack a trailing "\boxed{...}".
    if idx >= 0 and string.startswith("\\boxed ", idx):
        # Space form: the answer ends at the closing math delimiter, if any.
        return "\\boxed " + string[idx + len("\\boxed "):].split("$")[0]
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None
    depth = 0
    for pos in range(idx, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return string[idx : pos + 1]
    return None
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def fix_fracs(string):
    """Brace bare ``\\frac`` arguments, e.g. ``\\frac12`` -> ``\\frac{1}{2}``.

    A ``\\frac`` already followed by ``{`` is left unchanged. Otherwise the first
    following character becomes the braced numerator, and the second becomes the
    braced denominator unless it is ``{`` (``\\frac1{72}`` -> ``\\frac{1}{72}``).
    If any such ``\\frac`` is followed by fewer than two characters (including a
    trailing ``\\frac``), the input is returned unchanged.
    """
    head, *tails = string.split("\\frac")
    pieces = [head]
    for tail in tails:
        pieces.append("\\frac")
        if tail.startswith("{"):
            pieces.append(tail)
            continue
        if len(tail) < 2:
            # Malformed fraction: leave the whole string untouched (previously an
            # empty tail raised IndexError instead).
            return string
        num, den, rest = tail[0], tail[1], tail[2:]
        if den == "{":
            # Only the numerator is bare; the denominator group is already braced.
            pieces.append("{" + num + "}" + den + rest)
        else:
            pieces.append("{" + num + "}{" + den + "}" + rest)
    return "".join(pieces)
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def fix_a_slash_b(string):
    """Rewrite a plain integer fraction ``"a/b"`` as ``"\\frac{a}{b}"``.

    Only strings made of exactly two integers around a single ``/``, each already
    in canonical ``int`` spelling, are rewritten; e.g. ``"01/2"``, ``" 1/2"``,
    ``"1/2/3"`` and ``"a/b"`` are returned unchanged.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        a, b = int(parts[0]), int(parts[1])
    except ValueError:
        return string
    # int() tolerates whitespace, signs, underscores and leading zeros; require the
    # exact canonical spelling (explicit check, not an assert stripped by -O).
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def remove_right_units(string):
    """Cut off a unit suffix introduced by ``"\\text{ "`` (note the trailing space).

    Returns the text before the marker, or ``string`` unchanged when the marker is
    absent. If the marker occurs more than once an ``assert`` fails (so this
    raises AssertionError unless Python runs with ``-O``).
    """
    marker = "\\text{ "
    if marker not in string:
        return string
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    before, *after = string.split(marker)
    assert len(after) == 1
    return before
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def fix_sqrt(string):
    """Brace the first character after each bare ``\\sqrt``: ``\\sqrt3`` -> ``\\sqrt{3}``.

    Arguments already starting with ``{`` are left alone. Only one character is
    wrapped (``\\sqrt12`` -> ``\\sqrt{1}2``). A ``\\sqrt`` with nothing after it
    (end of string, or directly followed by another ``\\sqrt``) is kept verbatim
    instead of raising IndexError.
    """
    if "\\sqrt" not in string:
        return string
    head, *tails = string.split("\\sqrt")
    pieces = [head]
    for tail in tails:
        if tail and tail[0] != "{":
            pieces.append("\\sqrt{" + tail[0] + "}" + tail[1:])
        else:
            pieces.append("\\sqrt" + tail)
    return "".join(pieces)
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def strip_string(string):
    """Normalize a LaTeX answer string so that equivalent answers compare equal.

    Steps, in order:
      * drop line breaks, ``\\!``, ``\\left``/``\\right``, degree marks and
        escaped dollar signs; collapse double backslashes; ``tfrac``/``dfrac`` -> ``frac``;
      * drop a trailing ``"\\text{ "`` unit and escaped percent signs;
      * add a leading ``0`` to bare decimals;
      * drop a short ``"k = "``-style prefix;
      * remove spaces;
      * brace ``\\sqrt`` and ``\\frac`` arguments;
      * map ``0.5`` -> ``\\frac{1}{2}`` and simple ``a/b`` -> ``\\frac{a}{b}``.

    Exceptions raised by the helpers (e.g. the ``assert`` in
    ``remove_right_units``) propagate to the caller.
    """
    # linebreaks
    string = string.replace("\n", "")
    # remove inverse spaces
    string = string.replace("\\!", "")
    # collapse double backslashes into single ones
    string = string.replace("\\\\", "\\")
    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")
    # remove escaped dollar signs
    string = string.replace("\\$", "")
    # remove units (on the right)
    string = remove_right_units(string)
    # remove escaped percent signs ("\%")
    string = string.replace("\\%", "")
    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string
    # to consider: get rid of e.g. "k = " or "q = " at beginning
    lhs, _, rhs = string.partition("=")
    if string.count("=") == 1 and len(lhs) <= 2:
        string = rhs
    # remove spaces BEFORE fix_sqrt; otherwise "\\sqrt 3" is braced as "\\sqrt{ }3"
    # and ends up as "\\sqrt{}3" instead of "\\sqrt{3}"
    string = string.replace(" ", "")
    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)
    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = fix_fracs(string)
    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"
    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = fix_a_slash_b(string)
    return string