# -*- coding: utf-8 -*-
"""
Customized tool-call workflows are defined in this file.
Code adapted from https://github.com/NVlabs/Tool-N1
See the reference paper https://arxiv.org/pdf/2505.00024 for further details.
"""
import json
import re
from collections import Counter
from typing import List
from trinity.common.experience import Experience
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
from trinity.utils.log import get_logger
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
# Adapted from https://github.com/NVlabs/Tool-N1
# Instruction template that `construct_prompt` appends to the system message.
# It is rendered with `str.format(tools=...)`, so `{tools}` is the only
# placeholder and every literal brace in the JSON examples is doubled
# (`{{` / `}}`). The text is sent to the model verbatim; edit it with care.
qwen_tool_prompts = """# Tool
<tools>
{tools}
</tools>
In each action step, you MUST:
1. Think about the reasoning process in the mind and enclosed your reasoning within <think> </think> XML tags.
2. Then, provide a json object with function names and arguments within <tool_call></tool_call> XML tags. i.e., <tool_call>[{{"name": <function-name>, "arguments": <args-json-object>}}, {{"name": <function-name2>, "arguments": <args-json-object2>}}, ...]</tool_call>
3. Make sure both the reasoning and the tool call steps are included together in one single reply.
A complete reply example is: <think>To address the query, I need to send the email to Bob and then buy the banana through walmart. </think> <tool_call> [{{"name": "email", "arguments": {{"receiver": "Bob", "content": "I will bug banana through walmart"}}}}, {{"name": "walmart", "arguments": {{"input": "banana"}}}}]</tool_call>. Please make sure the type of the arguments is correct.
"""
# Adapted from https://github.com/NVlabs/Tool-N1
def construct_prompt(dp):
    """Build the chat message list for one tool-call sample.

    Args:
        dp: Mapping from which the following keys are read:
            ``"tools"`` - a JSON string that decodes to an iterable of tool
            specs (presumably a list); ``"raw_system"`` - the system prompt
            text; ``"conversations"`` - an iterable of turns, each a mapping
            with ``"from"`` (sender tag) and ``"value"`` (content).

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts. The first entry is
        the system message (``raw_system`` followed by the rendered
        ``qwen_tool_prompts``); the rest mirror the conversation turns.
        Turns whose ``"from"`` tag is not recognized are skipped.
    """

    def format_tools(tools):
        # One JSON object per line, without a trailing newline. Using join
        # (instead of stripping the last character afterwards) also handles
        # an empty tool list, which previously raised IndexError on
        # ``string[-1]``.
        return "\n".join(
            json.dumps({"type": "function", "function": tool})
            for tool in json.loads(tools)
        )

    tool_prompt = qwen_tool_prompts.format(tools=format_tools(dp["tools"]))
    prompt = [{"role": "system", "content": dp["raw_system"] + tool_prompt}]

    for turn in dp["conversations"]:
        sender = turn["from"]
        if sender in ("human", "user"):
            prompt.append({"role": "user", "content": turn["value"]})
        elif sender in ("gpt", "assistant"):
            prompt.append({"role": "assistant", "content": turn["value"]})
        elif sender in ("observation", "tool"):
            prompt.append({"role": "tool", "content": turn["value"]})
        elif sender == "function_call":
            # Function calls are stored structured; serialize them so the
            # assistant message content is always a string.
            prompt.append({"role": "assistant", "content": json.dumps(turn["value"])})
    return prompt
# Adapted from https://github.com/NVlabs/Tool-N1
def validate_result(result, answer):
    """Grade predicted tool calls against the expected ones.

    Both arguments are sized collections of tool-call mappings with
    ``"name"`` and ``"arguments"`` entries. Calls are compared as multisets,
    so ordering does not matter but multiplicity does.

    Returns:
        2 - both collections are empty, or names and (key-sorted, JSON
            serialized) arguments match exactly;
        1 - only the multiset of call names matches;
        0 - otherwise, including when exactly one side is empty or when
            building the full comparison raises ``TypeError``.
    """
    if not len(result) or not len(answer):
        # At least one side is empty: it is a match only if both are.
        return 2 if len(result) == len(answer) else 0

    def _signature(call):
        # Canonical, hashable form of one call; sort_keys makes the
        # comparison independent of argument key order.
        return call["name"], json.dumps(call["arguments"], sort_keys=True)

    try:
        predicted_calls = Counter(map(_signature, result))
        expected_calls = Counter(map(_signature, answer))
    except TypeError:
        return 0

    if predicted_calls == expected_calls:
        return 2

    predicted_names = Counter(call["name"] for call in result)
    expected_names = Counter(call["name"] for call in answer)
    return 1 if predicted_names == expected_names else 0
# Adapted from https://github.com/NVlabs/Tool-N1
def compute_score_v0(  # noqa: C901
    solution_str,
    ground_truth,
    do_print=False,
):
    """Score a model reply against the ground-truth tool calls (0 or 1).

    Args:
        solution_str: Raw model reply; it is passed to
            ``extract_solution_v0`` and also inspected directly for tag
            counts and positions.
        ground_truth: JSON string with the expected tool calls. If it
            decodes to a string, that string is decoded once more.
        do_print: When true, print the inputs and the reason for the verdict.

    Returns:
        1 only when every format rule passes and ``validate_result`` reports
        a full match (2); 0 in every other case.

    Note:
        The verbose messages print ``-1`` for failures although the value
        returned is 0; the text is kept as-is.
    """

    def _finish(score, *message):
        # Shared exit path: optional separator + verdict line, then the score.
        if do_print:
            print("--------" * 5 + "\n\n")
            print(*message)
        return score

    answer = json.loads(ground_truth)
    result, output_string = extract_solution_v0(solution_str)

    # Normalize the extracted payload: JSON text -> object, single call -> list.
    if isinstance(result, str):
        try:
            result = json.loads(result)
        except json.JSONDecodeError:
            result = None
    if isinstance(result, dict):
        result = [result]
    if isinstance(answer, str):
        answer = json.loads(answer)

    if do_print:
        print("************solution_str************")
        print(solution_str)
        print(f"Extracted result: {result}")
        print(f"Solution string: {answer}")

    if result is None:
        return _finish(0, "result is None:", -1)
    if not ("<think>" in output_string and "</think>" in output_string):
        return _finish(0, "not thinking:", -1)

    # Rule 1: exactly one opening and one closing think tag in the reply.
    open_count = solution_str.count("<think>")
    close_count = solution_str.count("</think>")
    if not (open_count == 1 and close_count == 1):
        return _finish(
            0,
            f"Fail, think tag appear not once: "
            f"<think> appear {open_count} times, "
            f"</think> appear {close_count} times",
            -1,
        )

    # Rule 2: when a tool call is present, the reasoning must end before it.
    think_close_at = solution_str.find("</think>")
    call_open_at = solution_str.find("<tool_call>")
    if call_open_at != -1 and think_close_at > call_open_at:
        return _finish(0, "Fail: <think> tag must before <tool_call> tag", -1)

    if not validate_format(result):
        return _finish(0, "result wrong formate:", -1)

    if validate_result(result, answer) == 2:
        return _finish(1, "get full core:", 1)
    return _finish(0, "wrong answer", -1)