Source code for trinity.common.workflows.customized_toolcall_workflows

# -*- coding: utf-8 -*-
"""
This file contains the customized tool-call workflows.
Code adapted from https://github.com/NVlabs/Tool-N1.
See the reference paper https://arxiv.org/pdf/2505.00024 for further details.
"""

import json
import re
from collections import Counter
from typing import List

from trinity.common.experience import Experience
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
from trinity.utils.log import get_logger

logger = get_logger(__name__)

# Adapted from https://github.com/NVlabs/Tool-N1
qwen_tool_prompts = """# Tool

<tools>
{tools}
</tools>

In each action step, you MUST:
1. Think about the reasoning process in your mind and enclose your reasoning within <think> </think> XML tags.
2. Then, provide a json object with function names and arguments within <tool_call></tool_call> XML tags. i.e., <tool_call>[{{"name": <function-name>, "arguments": <args-json-object>}}, {{"name": <function-name2>, "arguments": <args-json-object2>}}, ...]</tool_call>
3. Make sure both the reasoning and the tool call steps are included together in one single reply.
A complete reply example is: <think>To address the query, I need to send the email to Bob and then buy the banana through walmart. </think> <tool_call> [{{"name": "email", "arguments": {{"receiver": "Bob", "content": "I will buy banana through walmart"}}}}, {{"name": "walmart", "arguments": {{"input": "banana"}}}}]</tool_call>. Please make sure the type of the arguments is correct.
"""


# Adapted from https://github.com/NVlabs/Tool-N1
def construct_prompt(dp):
    """Build a chat-format prompt (system/user/assistant/tool messages) from a raw data point."""

    def format_tools(tools):
        # Serialize each tool spec as one JSON line inside the <tools> block.
        tools = json.loads(tools)
        string = ""
        for tool in tools:
            string += json.dumps({"type": "function", "function": tool}) + "\n"
        if string[-1] == "\n":
            string = string[:-1]
        return string

    tools = format_tools(dp["tools"])
    tool_prompt = qwen_tool_prompts.format(tools=tools)
    system = dp["raw_system"]
    conversations = dp["conversations"]
    prompt = []
    prompt.append({"role": "system", "content": system + tool_prompt})
    for tem in conversations:
        if tem["from"] == "human" or tem["from"] == "user":
            prompt.append({"role": "user", "content": tem["value"]})
        elif tem["from"] == "gpt" or tem["from"] == "assistant":
            prompt.append({"role": "assistant", "content": tem["value"]})
        elif tem["from"] == "observation" or tem["from"] == "tool":
            prompt.append({"role": "tool", "content": tem["value"]})
        elif tem["from"] == "function_call":
            prompt.append({"role": "assistant", "content": json.dumps(tem["value"])})
    return prompt
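
# Illustrative usage (a sketch, not part of the original module). The data point below is
# hypothetical and only shows the expected shape: "tools" is a JSON string, and
# "conversations" is a list of {"from", "value"} turns.
#
#   dp = {
#       "tools": json.dumps([{"name": "email", "parameters": {"receiver": "str", "content": "str"}}]),
#       "raw_system": "You are a helpful assistant.",
#       "conversations": [{"from": "human", "value": "Email Bob that I will be late."}],
#   }
#   messages = construct_prompt(dp)
#   # -> [{"role": "system", "content": <system + tool prompt>},
#   #     {"role": "user", "content": "Email Bob that I will be late."}]
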
# Adapted from https://github.com/NVlabs/Tool-N1
def validate_result(result, answer):
    """Compare predicted tool calls with the ground truth.

    Returns 2 for an exact match (names and arguments), 1 if only the function
    names match, and 0 otherwise.
    """
    if len(result) == 0 or len(answer) == 0:
        if len(result) == len(answer):
            return 2
        else:
            return 0
    else:
        try:
            counter1_full = Counter(
                (item["name"], json.dumps(item["arguments"], sort_keys=True)) for item in result
            )
            counter2_full = Counter(
                (item["name"], json.dumps(item["arguments"], sort_keys=True)) for item in answer
            )
        except TypeError:
            return 0
        if counter1_full == counter2_full:
            return 2
        counter1_name = Counter(item["name"] for item in result)
        counter2_name = Counter(item["name"] for item in answer)
        if counter1_name == counter2_name:
            return 1
        return 0
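
# Illustrative behaviour (a sketch, not part of the original module), using a
# hypothetical "email" tool:
#
#   gold = [{"name": "email", "arguments": {"receiver": "Bob"}}]
#   validate_result([{"name": "email", "arguments": {"receiver": "Bob"}}], gold)    # -> 2 (exact match)
#   validate_result([{"name": "email", "arguments": {"receiver": "Alice"}}], gold)  # -> 1 (names match only)
#   validate_result([{"name": "walmart", "arguments": {"input": "banana"}}], gold)  # -> 0 (mismatch)
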
# Adapted from https://github.com/NVlabs/Tool-N1
def validate_format(tool_call_list):
    """Check that every parsed tool call is a dict with 'name' and 'arguments' keys."""
    for item in tool_call_list:
        if not isinstance(item, dict):
            return 0
    for item in tool_call_list:
        if "name" not in item.keys() or "arguments" not in item.keys():
            return 0
    return 1
# Adapted from https://github.com/NVlabs/Tool-N1
def extract_solution_v0(tool_call_str):
    """Extract the JSON payload of the last <tool_call>...</tool_call> block, if any."""
    output_string = tool_call_str
    pattern = r"<tool_call>(.*?)</tool_call>"
    matches = list(re.finditer(pattern, tool_call_str, flags=re.DOTALL))
    if not matches:
        return None, output_string
    last_content = matches[-1].group(1).strip()
    try:
        return json.loads(last_content), output_string
    except json.JSONDecodeError:
        return None, output_string
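
# Illustrative behaviour (a sketch, not part of the original module): only the last
# <tool_call> block is parsed; invalid or missing JSON yields (None, original_string).
#
#   reply = (
#       "<think>Send the email first.</think>"
#       '<tool_call>[{"name": "email", "arguments": {"receiver": "Bob"}}]</tool_call>'
#   )
#   extract_solution_v0(reply)
#   # -> ([{"name": "email", "arguments": {"receiver": "Bob"}}], reply)
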
def compute_score_v0(  # noqa: C901
    solution_str,
    ground_truth,
    do_print=False,
):
    """Score a model reply against the ground-truth tool calls (0 or 1)."""
    answer = json.loads(ground_truth)
    result, output_string = extract_solution_v0(solution_str)
    if isinstance(result, str):
        try:
            result = json.loads(result)
        except json.JSONDecodeError:
            result = None
    if isinstance(result, dict):
        tem = []
        tem.append(result)
        result = tem
    if isinstance(answer, str):
        answer = json.loads(answer)
    if do_print:
        print("************solution_str************")
        print(solution_str)
        print(f"Extracted result: {result}")
        print(f"Solution string: {answer}")
    if result is not None:
        if "<think>" not in output_string or "</think>" not in output_string:
            if do_print:
                print("--------" * 5 + "\n\n")
                print("not thinking:", -1)
            return 0
    if result is None:
        if do_print:
            print("--------" * 5 + "\n\n")
            print("result is None:", -1)
        return 0
    # added rule1
    if solution_str.count("<think>") != 1 or solution_str.count("</think>") != 1:
        if do_print:
            print("--------" * 5 + "\n\n")
            print(
                f"Fail, think tag does not appear exactly once: "
                f"<think> appears {solution_str.count('<think>')} times, "
                f"</think> appears {solution_str.count('</think>')} times",
                -1,
            )
        return 0
    # added rule2
    think_end_pos = solution_str.find("</think>")
    tool_call_start_pos = solution_str.find("<tool_call>")
    if tool_call_start_pos != -1:
        if think_end_pos > tool_call_start_pos:
            if do_print:
                print("--------" * 5 + "\n\n")
                print("Fail: <think> tag must come before <tool_call> tag", -1)
            return 0
    if not validate_format(result):
        if do_print:
            print("--------" * 5 + "\n\n")
            print("result has wrong format:", -1)
        return 0
    if validate_result(result, answer) == 2:
        if do_print:
            print("--------" * 5 + "\n\n")
            print("got full score:", 1)
        return 1
    else:
        if do_print:
            print("--------" * 5 + "\n\n")
            print("wrong answer", -1)
        return 0
def compute_toolcall_reward(
    solution_str: str,
    ground_truth: str,
) -> float:
    res = compute_score_v0(solution_str, ground_truth)
    if isinstance(res, (int, float, bool)):
        return float(res)
    else:
        return float(res[0])
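
# Illustrative end-to-end reward (a sketch, not part of the original module): a reply
# with exactly one <think> block before its <tool_call>, whose calls match the ground
# truth, scores 1.0; any format or answer violation scores 0.0.
#
#   ground_truth = json.dumps([{"name": "email", "arguments": {"receiver": "Bob", "content": "hi"}}])
#   reply = (
#       "<think>I should email Bob.</think>"
#       '<tool_call>[{"name": "email", "arguments": {"receiver": "Bob", "content": "hi"}}]</tool_call>'
#   )
#   compute_toolcall_reward(reply, ground_truth)  # -> 1.0
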
@WORKFLOWS.register_module("toolcall_workflow")
class ToolCallWorkflow(SimpleWorkflow):
    """A workflow for tool-call tasks.

    Prompt construction and reward function are adapted from
    https://github.com/NVlabs/Tool-N1.

    Only Qwen models are supported for now. Adapt the prompt construction and
    reward calculation yourself for other models.
    """

    def reset(self, task: Task):
        self.format_args = task.format_args
        self.system_prompt = task.format_args.system_prompt
        self.reply_prefix = task.format_args.reply_prefix

        self.raw_task = task.raw_task
        self.task_desc = task.task_desc
        self.truth = task.truth

        self.workflow_args = task.workflow_args
        self.reward_fn_args = task.reward_fn_args

    def format_prompt(self):
        raw_task = self.raw_task
        messages = construct_prompt(raw_task)
        return messages

    def run(self) -> List[Experience]:
        messages = self.format_prompt()
        logger.debug("start chat")
        responses = self.model.chat(messages, **self.rollout_args)
        for i, response in enumerate(responses):
            reward = 0.0
            if self.raw_task is not None:
                ground_truth = self.raw_task.get("answer")
                if ground_truth is not None:
                    reward = compute_toolcall_reward(
                        solution_str=response.response_text,
                        ground_truth=ground_truth,
                    )
                else:
                    logger.error(
                        "Key 'answer' not found in self.raw_task. Assigning default reward."
                    )
            else:
                logger.error("self.raw_task is None. Assigning default reward.")
            logger.debug(
                f"self.task_desc: {self.task_desc}, messages: {messages}, "
                f"response: {response.response_text}, reward: {reward}"
            )
            response.reward = reward
            response.eid.run = i + self.run_id_base
        return responses