# -*- coding: utf-8 -*-
"""
We include the customized toolcall workflows in this file.
Code adapted from https://github.com/NVlabs/Tool-N1
Reference Paper https://arxiv.org/pdf/2505.00024 for further details.
"""
import json
import re
from collections import Counter
from typing import List
from trinity.common.experience import Experience
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
# Adapted from https://github.com/NVlabs/Tool-N1
qwen_tool_prompts = """# Tool
<tools>
{tools}
</tools>
In each action step, you MUST:
1. Think about the reasoning process in the mind and enclosed your reasoning within <think> </think> XML tags.
2. Then, provide a json object with function names and arguments within <tool_call></tool_call> XML tags. i.e., <tool_call>[{{"name": <function-name>, "arguments": <args-json-object>}}, {{"name": <function-name2>, "arguments": <args-json-object2>}}, ...]</tool_call>
3. Make sure both the reasoning and the tool call steps are included together in one single reply.
A complete reply example is: <think>To address the query, I need to send the email to Bob and then buy the banana through walmart. </think> <tool_call> [{{"name": "email", "arguments": {{"receiver": "Bob", "content": "I will bug banana through walmart"}}}}, {{"name": "walmart", "arguments": {{"input": "banana"}}}}]</tool_call>. Please make sure the type of the arguments is correct.
"""
# Adapted from https://github.com/NVlabs/Tool-N1
def construct_prompt(dp):
def format_tools(tools):
tools = json.loads(tools)
string = ""
for tool in tools:
string += json.dumps({"type": "function", "function": tool}) + "\n"
if string[-1] == "\n":
string = string[:-1]
return string
tools = format_tools(dp["tools"])
tool_prompt = qwen_tool_prompts.format(tools=tools)
system = dp["raw_system"]
conversations = dp["conversations"]
prompt = []
prompt.append({"role": "system", "content": system + tool_prompt})
for tem in conversations:
if tem["from"] == "human" or tem["from"] == "user":
prompt.append({"role": "user", "content": tem["value"]})
elif tem["from"] == "gpt" or tem["from"] == "assistant":
prompt.append({"role": "assistant", "content": tem["value"]})
elif tem["from"] == "observation" or tem["from"] == "tool":
prompt.append({"role": "tool", "content": tem["value"]})
elif tem["from"] == "function_call":
prompt.append({"role": "assistant", "content": json.dumps(tem["value"])})
return prompt
# Adapted from https://github.com/NVlabs/Tool-N1
def validate_result(result, answer):
if len(result) == 0 or len(answer) == 0:
if len(result) == len(answer):
return 2
else:
return 0
else:
try:
counter1_full = Counter(
(item["name"], json.dumps(item["arguments"], sort_keys=True)) for item in result
)
counter2_full = Counter(
(item["name"], json.dumps(item["arguments"], sort_keys=True)) for item in answer
)
except TypeError:
return 0
if counter1_full == counter2_full:
return 2
counter1_name = Counter(item["name"] for item in result)
counter2_name = Counter(item["name"] for item in answer)
if counter1_name == counter2_name:
return 1
return 0
# Adapted from https://github.com/NVlabs/Tool-N1
def validate_format(tool_call_list):
for item in tool_call_list:
if not isinstance(item, dict):
return 0
for item in tool_call_list:
if "name" not in item.keys() or "arguments" not in item.keys():
return 0
return 1
# Adapted from https://github.com/NVlabs/Tool-N1
def extract_solution_v0(tool_call_str):
output_string = tool_call_str
pattern = r"<tool_call>(.*?)</tool_call>"
matches = list(re.finditer(pattern, tool_call_str, flags=re.DOTALL))
if not matches:
return None, output_string
last_content = matches[-1].group(1).strip()
try:
return json.loads(last_content), output_string
except json.JSONDecodeError:
return None, output_string
def compute_score_v0( # noqa: C901
solution_str,
ground_truth,
do_print=False,
):
answer = json.loads(ground_truth)
result, output_string = extract_solution_v0(solution_str)
if isinstance(result, str):
try:
result = json.loads(result)
except json.JSONDecodeError:
result = None
if isinstance(result, dict):
tem = []
tem.append(result)
result = tem
if isinstance(answer, str):
answer = json.loads(answer)
if do_print:
print("************solution_str************")
print(solution_str)
print(f"Extracted result: {result}")
print(f"Solution string: {answer}")
if result is not None:
if "<think>" not in output_string or "</think>" not in output_string:
if do_print:
print("--------" * 5 + "\n\n")
print("not thinking:", -1)
return 0
if result is None:
if do_print:
print("--------" * 5 + "\n\n")
print("result is None:", -1)
return 0
# added rule1
if solution_str.count("<think>") != 1 or solution_str.count("</think>") != 1:
if do_print:
print("--------" * 5 + "\n\n")
print(
f"Fail, think tag appear not once: "
f"<think> appear {solution_str.count('<think>')} times, "
f"</think> appear {solution_str.count('</think>')} times",
-1,
)
return 0
# added rule2
think_end_pos = solution_str.find("</think>")
tool_call_start_pos = solution_str.find("<tool_call>")
if tool_call_start_pos != -1:
if think_end_pos > tool_call_start_pos:
if do_print:
print("--------" * 5 + "\n\n")
print("Fail: <think> tag must before <tool_call> tag", -1)
return 0
if not validate_format(result):
if do_print:
print("--------" * 5 + "\n\n")
print("result wrong formate:", -1)
return 0
if validate_result(result, answer) == 2:
if do_print:
print("--------" * 5 + "\n\n")
print("get full core:", 1)
return 1
else:
if do_print:
print("--------" * 5 + "\n\n")
print("wrong answer", -1)
return 0
def compute_toolcall_reward(
solution_str: str,
ground_truth: str,
) -> float:
res = compute_score_v0(solution_str, ground_truth)
if isinstance(res, (int, float, bool)):
return float(res)
else:
return float(res[0])