Source code for trinity.utils.eval_utils

# -*- coding: utf-8 -*-
import regex as re

# Matches the literal marker "#### " followed by a captured run of digits, dots
# and commas with an optional leading minus sign.
# NOTE(review): presumably a GSM8K-style final-answer marker; not used in the
# visible part of this module -- confirm against callers.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# NOTE(review): presumably a sentinel for "no answer could be parsed"; not used
# in the visible part of this module -- confirm against callers.
INVALID_ANS = "[invalid]"


def simple_answer_parser(response: str) -> str:
    """Parse a model response with `math_verify.parse`.

    When the response contains an `<answer>...</answer>` span, only the text
    inside the first such span is parsed; otherwise the whole response is.

    Args:
        response (`str`): raw response from the model

    Returns:
        Whatever `math_verify.parse` returns for the selected text.
        NOTE(review): the signature is annotated `str`; confirm that this
        matches the actual return type of `math_verify.parse`.
    """
    # Imported lazily so the module can be loaded without math_verify installed.
    from math_verify import parse

    tagged = re.search(r"<answer>(.*?)</answer>", response)
    to_parse = tagged.group(1) if tagged else response
    return parse(to_parse)
def find_boxed_answer(raw_answer, timeout=10):
    """
    Return the content of the last LaTeX `\\boxed` tag found in a solution.

    Args:
        raw_answer (`str`): raw answer from model
        timeout (`int`): timeout in seconds passed to the regex search

    Returns:
        `str`: the boxed content (without the surrounding braces when the
        match starts with one), or None when nothing is boxed
    """
    # `(?2)` recurses into group 2 so nested braces are matched as a unit; the
    # final alternative captures a single character following `\boxed`.
    boxed_matches = re.findall(
        r"\\boxed\s*(({(?:\\.|[^{}]|(?2))*})|(.))",
        raw_answer,
        timeout=timeout,
    )
    if not boxed_matches:
        return None

    # The last boxed expression is regarded as the answer; element 0 of each
    # match tuple is the outermost capture group.
    content = boxed_matches[-1][0]
    return content[1:-1] if content.startswith("{") else content
# Copied from Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py
def extract_solution(solution_str):
    """Return the last `<answer>...</answer>` payload on the final line.

    Only the text after the last newline of `solution_str` is searched.

    Args:
        solution_str (`str`): full solution text

    Returns:
        `str`: the whitespace-stripped content of the last answer tag on that
        line, or None when the line holds no answer tag
    """
    final_line = solution_str.split("\n")[-1]
    payloads = re.findall(r"<answer>(.*?)</answer>", final_line)
    return payloads[-1].strip() if payloads else None
# Copied from Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py
def validate_equation(equation_str, available_numbers):
    """Check that an equation uses exactly the available numbers, each once.

    Args:
        equation_str (`str`): equation text; every run of digits counts as one
            number
        available_numbers: iterable of the integers that must be used

    Returns:
        `bool`: True when the sorted numbers found in the equation equal the
        sorted available numbers; False otherwise, including when any error
        occurs while comparing
    """
    try:
        used = sorted(int(token) for token in re.findall(r"\d+", equation_str))
        # Comparing sorted lists enforces "same multiset": each number exactly once.
        return used == sorted(available_numbers)
    except Exception:
        return False
# Copied from Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py
def evaluate_equation(equation_str):
    """Evaluate an arithmetic expression made of digits, operators and parentheses.

    Args:
        equation_str (`str`): expression to evaluate

    Returns:
        The value of the expression, or None when it contains characters
        outside the allow-list or when evaluation fails for any reason.
    """
    try:
        # Allow-list: digits, + - * / ( ) . and whitespace only.
        if re.match(r"^[\d+\-*/().\s]+$", equation_str) is None:
            raise ValueError("Invalid characters in equation.")
        # NOTE(review): eval() runs with builtins disabled, but the allow-list
        # still admits "**" (two "*" characters), so chained exponents can be
        # very expensive on untrusted input -- consider an AST-based evaluator.
        return eval(equation_str, {"__builtins__": None}, {})
    except Exception:
        return None
def validate_think_pattern(text):
    """Check that `text` has exactly one `<think>` ... `</think>` pair in order.

    Args:
        text (`str`): text to inspect

    Returns:
        `bool`: True only when each tag occurs exactly once and the opening
        tag appears before the closing tag
    """
    open_tag, close_tag = "<think>", "</think>"
    if text.count(open_tag) != 1 or text.count(close_tag) != 1:
        return False
    return text.find(open_tag) < text.find(close_tag)
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def compute_score(solution_str, ground_truth) -> float:
    """Score a solution by comparing its last boxed answer to the ground truth.

    Args:
        solution_str (`str`): full solution text produced by the model
        ground_truth (`str`): reference answer

    Returns:
        `float`: 1.0 when the last boxed answer is equivalent to
        `ground_truth` according to `is_equiv`; 0.0 otherwise, including when
        no boxed answer is found or an exception is raised (the exception is
        printed, not propagated)
    """
    try:
        boxed = last_boxed_only_string(solution_str)
        if boxed is None:
            return 0.0
        if is_equiv(remove_boxed(boxed), ground_truth):
            return 1.0
    except Exception as err:
        # Best-effort scoring: report the problem and fall through to 0.0.
        print(err)
    return 0.0
# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
def is_equiv(str1, str2, verbose=False):
    """Compare two answer strings after normalizing them with `strip_string`.

    Args:
        str1: first answer, or None
        str2: second answer, or None
        verbose (`bool`): when True, print both normalized strings

    Returns:
        `bool`: True when both inputs are None (a warning is printed), False
        when exactly one is None, otherwise whether the normalized strings are
        equal. If normalization raises, the raw inputs are compared instead.
    """
    if str1 is None or str2 is None:
        if str1 is None and str2 is None:
            print("WARNING: Both None")
            return True
        return False

    try:
        left, right = strip_string(str1), strip_string(str2)
        if verbose:
            print(left, right)
        return left == right
    except Exception:
        # Normalization failed; fall back to comparing the raw inputs.
        return str1 == str2
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def remove_boxed(s):
    """Strip the LaTeX box wrapper from a boxed answer string.

    Accepted forms are `\\boxed <content>`, `\\boxed{<content>}` and
    `\\fbox{<content>}`. The `\\fbox` form is accepted because
    `last_boxed_only_string` can return such a span.

    Args:
        s (`str`): string that starts with one of the accepted wrappers

    Returns:
        `str`: the content without the wrapper

    Raises:
        AssertionError: when `s` does not have one of the accepted forms.
            Raised explicitly so the check still runs under `python -O`.
    """
    spaced = "\\boxed "
    if spaced in s:
        if not s.startswith(spaced):
            raise AssertionError("expected string to start with '\\boxed '")
        return s[len(spaced) :]

    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left):
            if not s.endswith("}"):
                raise AssertionError("expected boxed string to end with '}'")
            return s[len(left) : -1]

    raise AssertionError("expected string to start with '\\boxed{' or '\\fbox{'")
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def last_boxed_only_string(string):
    """Return the last boxed expression in `string`, wrapper included.

    Args:
        string (`str`): text to search

    Returns:
        `str`: when the text contains `\\boxed ` (with a trailing space), that
        marker plus everything after its last occurrence up to the next `$`.
        Otherwise the span from the last `\\boxed` (or, when there is none, the
        last `\\fbox`) through the brace that closes it. None when no marker
        is present or the braces never balance.
    """
    spaced = "\\boxed "
    if spaced in string:
        tail = string.split(spaced)[-1]
        return spaced + tail.split("$")[0]

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    # Scan forward until the brace depth returns to zero right after a "}".
    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]
    return None
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def fix_fracs(string):
    """Add the missing braces to shorthand `\\frac` arguments.

    For example `\\frac12` becomes `\\frac{1}{2}` and `\\frac1{72}` becomes
    `\\frac{1}{72}`. Fractions whose first argument is already braced are left
    as they are.

    Args:
        string (`str`): text possibly containing `\\frac` commands

    Returns:
        `str`: the rewritten text, or the input unchanged when some `\\frac`
        is followed by fewer than two characters (this includes a trailing
        `\\frac`, which previously raised IndexError)
    """
    head, *tails = string.split("\\frac")
    pieces = [head]
    for tail in tails:
        pieces.append("\\frac")
        if tail.startswith("{"):
            # First argument already braced: keep the remainder verbatim.
            pieces.append(tail)
            continue
        if len(tail) < 2:
            # Not enough characters for two arguments: give up on the rewrite.
            return string
        first, second, rest = tail[0], tail[1], tail[2:]
        if second == "{":
            # Only the first argument is bare, e.g. \frac1{72}.
            pieces.append("{" + first + "}" + second + rest)
        else:
            pieces.append("{" + first + "}{" + second + "}" + rest)
    return "".join(pieces)
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def fix_a_slash_b(string):
    """Rewrite a plain integer ratio `a/b` as `\\frac{a}{b}`.

    Args:
        string (`str`): candidate text

    Returns:
        `str`: `\\frac{a}{b}` when `string` is exactly two integers in
        canonical form separated by one slash; otherwise `string` unchanged
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        a, b = int(parts[0]), int(parts[1])
    except ValueError:
        return string
    # int() tolerates forms such as " 1", "+1" or "01"; only rewrite when the
    # text round-trips exactly. This is an explicit comparison rather than an
    # assert so the behavior is the same under `python -O`.
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def remove_right_units(string):
    """Drop a trailing unit written as `\\text{ ...}`.

    Args:
        string (`str`): text to clean

    Returns:
        `str`: the part before the `\\text{ ` marker, or the input unchanged
        when the marker is absent

    Raises:
        AssertionError: when the marker occurs more than once
    """
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    marker = "\\text{ "
    if marker not in string:
        return string
    pieces = string.split(marker)
    assert len(pieces) == 2
    return pieces[0]
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def fix_sqrt(string):
    """Brace the argument of shorthand `\\sqrt`, e.g. `\\sqrt3` -> `\\sqrt{3}`.

    Only the single character right after `\\sqrt` is wrapped. Occurrences
    already followed by `{` are kept, and a `\\sqrt` at the very end of the
    text is left untouched (it previously raised IndexError).

    Args:
        string (`str`): text possibly containing `\\sqrt` commands

    Returns:
        `str`: the rewritten text
    """
    head, *tails = string.split("\\sqrt")
    pieces = [head]
    for tail in tails:
        if tail and tail[0] != "{":
            pieces.append("\\sqrt{" + tail[0] + "}" + tail[1:])
        else:
            # Already braced, or nothing follows the command.
            pieces.append("\\sqrt" + tail)
    return "".join(pieces)
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
def strip_string(string):
    """Normalize an answer string so that equivalent answers compare equal.

    The steps run in a fixed order: plain substring clean-ups, removal of
    right-hand units and percent signs, leading-zero fixes for decimals,
    dropping a short `x =` prefix, then the sqrt / fraction rewrites done by
    `fix_sqrt`, `fix_fracs` and `fix_a_slash_b`.

    Args:
        string (`str`): raw answer text

    Returns:
        `str`: the normalized text (empty when nothing is left after the
        initial clean-ups)
    """
    # Plain substring rewrites; later entries see the output of earlier ones,
    # so the order must be kept.
    ordered_rewrites = (
        ("\n", ""),  # linebreaks
        ("\\!", ""),  # inverse spaces
        ("\\\\", "\\"),  # replace \\ with \
        ("tfrac", "frac"),  # tfrac and dfrac become frac
        ("dfrac", "frac"),
        ("\\left", ""),  # remove \left and \right
        ("\\right", ""),
        ("^{\\circ}", ""),  # remove circ (degrees)
        ("^\\circ", ""),
        ("\\$", ""),  # remove dollar signs
    )
    for old, new in ordered_rewrites:
        string = string.replace(old, new)

    # remove units (on the right)
    string = remove_right_units(string)

    # remove percentage. The second call used to be spelled with the invalid
    # escape sequence "\%", which denotes the same two characters; it stays as
    # a second pass because deleting one occurrence can join a backslash and a
    # percent sign into a new pair.
    string = string.replace("\\%", "")
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add
    # "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with
    # \frac1{72} (but not \frac{72}1).
    string = fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in
    # case the model output is X/Y
    string = fix_a_slash_b(string)

    return string