In this document, we demonstrate how to implement and train, from scratch, an agent that uses Python to perform calculations and solve GSM8K math problems.
- Define agent workflow: create your agent with AgentScope, LangChain, the OpenAI SDK, or plain HTTP requests, and wrap it in a `Workflow` class.
- Define reward: configure how the agent's outputs are evaluated and scored.
- Prepare dataset: set up the dataset and configure the task reader.
- Debug (optional): test your workflow in debug mode before full training.
- Start training: launch the training process and track its progress.
The full code of this example is listed in the Full Code section at the end of this page.
Step 1: ✨Define agent Workflow + Reward
First of all, create a directory for this training project:
Next, define your workflow (or convert an existing one). Here we use AgentScope to implement the agent. The two versions below show the code after and before the conversion, so you can compare them directly. If you prefer LangChain or the OpenAI SDK, please refer to this article.
After conversion (AgentJet `Workflow`):

```python
# Imports, `system_prompt`, and `extract_final_answer` are defined in the Full Code section.
class MathToolWorkflow(Workflow):  # ✨✨ inherit `Workflow` class
    name: str = "math_agent_workflow"

    async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput:
        # run agentscope
        query = workflow_task.task.main_query
        self.toolkit = Toolkit()
        self.toolkit.register_tool_function(execute_python_code)
        self.agent = ReActAgent(
            name="math_react_agent", sys_prompt=system_prompt,
            model=tuner.as_agentscope_model(),  # ✨✨ the only difference from a normal AgentScope agent
            formatter=DashScopeChatFormatter(),
            toolkit=self.toolkit,
            memory=InMemoryMemory(), max_iters=2,
        )
        self.agent.set_console_output_enabled(False)
        msg = Msg("user", query, role="user")
        result = await self.agent.reply(msg)
        final_answer = extract_final_answer(result)

        # compute reward: exact match between the \boxed{...} content and the GSM8K reference
        reference_answer = workflow_task.task.metadata["answer"].split("####")[-1].strip()
        match = re.search(r"\\boxed\{([^}]*)\}", final_answer)
        is_success = match is not None and match.group(1) == reference_answer
        return WorkflowOutput(reward=(1.0 if is_success else 0.0), metadata={"final_answer": final_answer})
```
Before conversion (plain AgentScope agent):

```python
# Same imports as the Full Code section, plus:
import os
from agentscope.model import DashScopeChatModel


class MathToolWorkflow(object):
    name: str = "math_agent_workflow"

    async def execute(self, workflow_task: WorkflowTask) -> WorkflowOutput:
        # run agentscope
        query = workflow_task.task.main_query
        self.toolkit = Toolkit()
        self.toolkit.register_tool_function(execute_python_code)
        self.agent = ReActAgent(
            name="math_react_agent", sys_prompt=system_prompt,
            # a fixed, API-served model; the key is read from DASHSCOPE_API_KEY
            model=DashScopeChatModel(model_name="qwen-max", api_key=os.environ["DASHSCOPE_API_KEY"]),
            formatter=DashScopeChatFormatter(),
            toolkit=self.toolkit,
            memory=InMemoryMemory(), max_iters=2,
        )
        self.agent.set_console_output_enabled(False)
        msg = Msg("user", query, role="user")
        result = await self.agent.reply(msg)
        final_answer = extract_final_answer(result)

        # compute reward: exact match between the \boxed{...} content and the GSM8K reference
        reference_answer = workflow_task.task.metadata["answer"].split("####")[-1].strip()
        match = re.search(r"\\boxed\{([^}]*)\}", final_answer)
        is_success = match is not None and match.group(1) == reference_answer
        return WorkflowOutput(reward=(1.0 if is_success else 0.0), metadata={"final_answer": final_answer})
```
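Both versions score an episode the same way: the text after `####` in the GSM8K answer is the reference, the content of the first `\boxed{...}` in the agent's final answer is the prediction, and the reward is 1.0 only on an exact string match. As a minimal sketch, here is that logic factored into a standalone function so it can be checked without running an agent; the name `boxed_match_reward` is hypothetical and not part of AgentJet.

```python
import re

# Same pattern as the workflow: capture everything inside the first \boxed{...}.
BOXED = re.compile(r"\\boxed\{([^}]*)\}")


def boxed_match_reward(final_answer: str, gsm8k_answer: str) -> float:
    """Hypothetical helper: 1.0 if the first boxed value equals the GSM8K reference, else 0.0."""
    reference_answer = gsm8k_answer.split("####")[-1].strip()
    match = BOXED.search(final_answer)
    return 1.0 if match is not None and match.group(1) == reference_answer else 0.0


print(boxed_match_reward(r"So she sold \boxed{72} clips.", "... #### 72"))  # 1.0
print(boxed_match_reward(r"\boxed{72.0}", "... #### 72"))  # 0.0, exact string match
print(boxed_match_reward("The answer is 72.", "... #### 72"))  # 0.0, no \boxed{}
```

Because the comparison is a plain string equality, answers such as `72.0` or `72 clips` inside the box earn no reward; any normalization would have to be added to the workflow itself.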
Step 2: ✨Prepare dataset
Data Sources
AgentJet provides multiple ways to read data:
- Read from local files on disk
- Read from a Hugging Face repo
- Read from an EnvService
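Whichever source you use, each GSM8K record has a `question` string and an `answer` string whose worked solution ends with a line `#### <final answer>`. The sketch below uses a record modeled on the well-known first example of the train split. It assumes, based on how the workflow reads its input, that the task reader exposes `question` as `task.main_query` and keeps the raw `answer` in `task.metadata["answer"]`; `reference_from_gsm8k_answer` is a hypothetical helper name.

```python
# A record in the openai/gsm8k format: `question` plus an `answer` whose last line is "#### <number>".
record = {
    "question": (
        "Natalia sold clips to 48 of her friends in April, and then she sold "
        "half as many clips in May. How many clips did Natalia sell altogether "
        "in April and May?"
    ),
    "answer": (
        "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n"
        "Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n"
        "#### 72"
    ),
}


def reference_from_gsm8k_answer(answer: str) -> str:
    """Hypothetical helper: keep only the text after the last '####', as the workflow does."""
    return answer.split("####")[-1].strip()


# Assumed mapping from the raw record to the fields the workflow reads.
main_query = record["question"]
metadata = {"answer": record["answer"]}
print(reference_from_gsm8k_answer(metadata["answer"]))  # 72
```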
Download the openai/gsm8k dataset:
Now we have all the materials required to train the agent. They are tied together in the training configuration:
```yaml
# ------------------ main configuration ------------------
ajet:
  project_name: example_math_agent
  task_reader:
    type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo`
    # effective when `type: huggingface_dat_repo`
    huggingface_dat_repo:
      dataset_path: 'openai/gsm8k'
      training_split: "train"
      validation_split: "test"
  task_judge:
    # ✨✨✨✨ null, because in this example the reward is computed inside the workflow
    judge_protocol: null
  model:
    # ✨✨✨✨ set the model to be trained
    path: Qwen/Qwen2.5-7B
  rollout:
    user_workflow: "tutorial/example_math_agent/math_agent.py->MathToolWorkflow" # ✨✨✨✨ write and select workflow
    num_repeat: 6 # grpo `n`
    tensor_model_parallel_size: 1 # vllm tp
    max_response_length_in_one_turn: 1024
    max_model_len: 10000
  data:
    train_batch_size: 100
    max_prompt_length: 3000
    max_response_length: 7000
  debug:
    debug_max_parallel: 1
    debug_first_n_tasks: 1
  trainer_common:
    save_freq: 100
    test_freq: 100
    total_epochs: 100
    logger: swanlab

# ------------------ do not modify ------------------
hydra:
  searchpath:
    - file://ajet/default_config
    - file://ajet/default_config/verl
    - file://ajet/default_config/trinity

# ------------------ do not modify ------------------
defaults:
  - verl_default
  - trinity_default
  - ajet_default
  - _self_
```
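`num_repeat: 6` sets the GRPO group size `n`: each task is rolled out six times, and the rewards within that group are compared against each other. The sketch below shows the textbook GRPO advantage (reward minus group mean, divided by group standard deviation) for one group. It is an illustration only; the trainer backends that AgentJet builds on may use a different standard-deviation estimator or epsilon, and `grpo_advantages` is a hypothetical name.

```python
from statistics import mean, pstdev


def grpo_advantages(group_rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Textbook group-relative advantage: (r - group mean) / (group std + eps)."""
    mu = mean(group_rewards)
    sigma = pstdev(group_rewards)  # population std over the group's rollouts
    return [(r - mu) / (sigma + eps) for r in group_rewards]


# Six rollouts of one GSM8K task (num_repeat: 6), two of which earned reward 1.0.
rewards = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]
print([round(a, 3) for a in grpo_advantages(rewards)])  # [1.414, -0.707, -0.707, 1.414, -0.707, -0.707]

# If every rollout in a group gets the same reward, all advantages are zero.
print(grpo_advantages([1.0] * 6))  # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

With a binary reward, a group in which all six rollouts succeed (or all fail) contributes no learning signal under this formulation, which is one reason to sample several rollouts per task.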
Configuration Parameters
| Category | Parameter | Description | Example Value |
|---|---|---|---|
| Project | `project_name` | Name of the training project | `example_math_agent` |
| Task Reader | `type` | Type of data source to read tasks from | `huggingface_dat_repo` (options: `env_service`, `dataset_file`, `huggingface_dat_repo`) |
| | `dataset_path` | Path or identifier of the dataset | `openai/gsm8k` |
| | `training_split` | Dataset split used for training | `train` |
| | `validation_split` | Dataset split used for validation/testing | `test` |
| Model | `path` | Path or identifier of the model to be trained | `Qwen/Qwen2.5-7B` |
| Rollout | `user_workflow` | Workflow class to train, given as `<path to .py file>-><class name>` | `tutorial/example_math_agent/math_agent.py->MathToolWorkflow` |
| | `num_repeat` | Number of rollouts per task (the GRPO `n` parameter) | `6` |
| | `tensor_model_parallel_size` | vLLM tensor parallelism size | `1` |
| | `max_response_length_in_one_turn` | Maximum token length of a single agent response | `1024` |
| | `max_model_len` | Maximum total context length of the model | `10000` |
| Data | `train_batch_size` | Number of tasks per training batch | `100` |
| | `max_prompt_length` | Maximum token length of input prompts | `3000` |
| | `max_response_length` | Maximum token length of model responses | `7000` |
| Debug | `debug_max_parallel` | Maximum number of parallel workers in debug mode | `1` |
| | `debug_first_n_tasks` | Number of tasks to process in debug mode | `1` |
| Trainer | `save_freq` | Frequency (in steps) of saving model checkpoints | `100` |
| | `test_freq` | Frequency (in steps) of running validation | `100` |
| | `total_epochs` | Total number of training epochs | `100` |
| | `logger` | Logging backend for experiment tracking | `swanlab` |
| Task Judge | `judge_protocol` | Protocol for judging task completion | `null` (the reward is computed in the workflow) |
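A few of these values interact. The back-of-the-envelope sketch below derives two numbers from the example configuration: how many episodes one training step rolls out, and whether the prompt and response budgets together fit in the model's context window. It assumes that one step consumes one batch of `train_batch_size` tasks, each rolled out `num_repeat` times, and that `max_prompt_length + max_response_length` should not exceed `max_model_len`; these are assumptions about how AgentJet uses the fields, not documented guarantees. `rollout_budget` is a hypothetical helper.

```python
# Values copied from the example configuration (ajet.data and ajet.rollout).
config = {
    "train_batch_size": 100,
    "num_repeat": 6,
    "max_prompt_length": 3000,
    "max_response_length": 7000,
    "max_model_len": 10000,
    "max_response_length_in_one_turn": 1024,
}


def rollout_budget(cfg: dict) -> dict:
    """Hypothetical helper: derived quantities under the assumptions stated above."""
    tokens_per_episode = cfg["max_prompt_length"] + cfg["max_response_length"]
    return {
        # every task in the batch is rolled out `num_repeat` times (GRPO group size)
        "episodes_per_step": cfg["train_batch_size"] * cfg["num_repeat"],
        # worst-case tokens for one episode: prompt budget plus response budget
        "tokens_per_episode": tokens_per_episode,
        # does that worst case fit in the model context window?
        "fits_context": tokens_per_episode <= cfg["max_model_len"],
        # a single turn should fit inside the overall response budget
        "turn_fits_response": cfg["max_response_length_in_one_turn"] <= cfg["max_response_length"],
    }


print(rollout_budget(config))
```

With the example values this gives 600 episodes per step and exactly 10000 tokens per episode, matching `max_model_len`; if you raise one of the length budgets, you may need to raise `max_model_len` as well.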
Step 3: ✨Debug (Optional)
Before full training, you can run a quick test in debug mode with the raw base model to check whether the workflow has bugs. We use VS Code for debugging because it is open source and fast.
VS Code Debugging
- You can create `.vscode/launch.json` for breakpoint debugging:
After creating `.vscode/launch.json`, press F5 to start debugging. (Remember to select the Python interpreter of your virtual environment in VS Code.)
For more debugging techniques, please refer to the debugging guidelines.
Step 4: ✨Start Training
After debugging, launch the full training:
Output Location
By default, training logs and checkpoints are saved to:
Full Code
`tutorial/example_math_agent/math_agent.py`:

```python
import re

from loguru import logger
from agentscope.message import Msg
from agentscope.agent import ReActAgent
from agentscope.formatter import DashScopeChatFormatter
from agentscope.memory import InMemoryMemory
from agentscope.tool import Toolkit, execute_python_code
from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask


def extract_final_answer(result) -> str:
    """Extract the final answer from the agent's response."""
    try:
        if (
            hasattr(result, "metadata")
            and isinstance(result.metadata, dict)
            and "result" in result.metadata
        ):
            return result.metadata["result"]
        if hasattr(result, "content"):
            if isinstance(result.content, dict) and "result" in result.content:
                return result.content["result"]
            return str(result.content)
        return str(result)
    except Exception as e:
        logger.warning(f"Extract final answer error: {e}. Raw: {result}")
        return str(result)


# Not an f-string, so single braces: the model sees the literal text \boxed{}.
system_prompt = """
You are an agent specialized in solving math problems with tools.
Please solve the math problem given to you.
You can write and execute Python code to perform calculation or verify your answer.
You should return your final answer within \\boxed{}.
"""


class MathToolWorkflow(Workflow):  # ✨✨ inherit `Workflow` class
    name: str = "math_agent_workflow"

    async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput:
        # run agentscope
        query = workflow_task.task.main_query
        self.toolkit = Toolkit()
        self.toolkit.register_tool_function(execute_python_code)
        self.agent = ReActAgent(
            name="math_react_agent", sys_prompt=system_prompt,
            model=tuner.as_agentscope_model(),  # ✨✨ the only difference from a normal AgentScope agent
            formatter=DashScopeChatFormatter(),
            toolkit=self.toolkit,
            memory=InMemoryMemory(), max_iters=2,
        )
        self.agent.set_console_output_enabled(False)
        msg = Msg("user", query, role="user")
        result = await self.agent.reply(msg)
        final_answer = extract_final_answer(result)

        # compute reward: exact match between the \boxed{...} content and the GSM8K reference
        reference_answer = workflow_task.task.metadata["answer"].split("####")[-1].strip()
        match = re.search(r"\\boxed\{([^}]*)\}", final_answer)
        is_success = match is not None and match.group(1) == reference_answer
        return WorkflowOutput(reward=(1.0 if is_success else 0.0), metadata={"final_answer": final_answer})
```
Configuration file:

```yaml
# ------------------ main configuration ------------------
ajet:
  project_name: example_math_agent
  task_reader:
    type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo`
    # effective when `type: huggingface_dat_repo`
    huggingface_dat_repo:
      dataset_path: 'openai/gsm8k' # '/mnt/data_cpfs/dataset_cache/openai/gsm8k/main'
      training_split: "train"
      validation_split: "test"
  model:
    # ✨✨✨✨ set the model to be trained
    path: Qwen/Qwen2___5-7B-Instruct # /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct
  rollout:
    user_workflow: "tutorial/example_math_agent/math_agent.py->MathToolWorkflow" # ✨✨✨✨ write and select workflow
    num_repeat: 6 # grpo `n`
    tensor_model_parallel_size: 1 # vllm tp
    max_response_length_in_one_turn: 1024
    max_model_len: 10000
  task_judge:
    # ✨✨✨✨ null, because in this example the reward is computed inside the workflow
    judge_protocol: null
  data:
    train_batch_size: 100
    max_prompt_length: 3000
    max_response_length: 7000
  debug:
    debug_max_parallel: 1
    debug_first_n_tasks: 1
  trainer_common:
    save_freq: 100
    test_freq: 100
    total_epochs: 100
    logger: swanlab

# ------------------ do not modify ------------------
hydra:
  searchpath:
    - file://ajet/default_config
    - file://ajet/default_config/verl
    - file://ajet/default_config/trinity

# ------------------ do not modify ------------------
defaults:
  - verl_default
  - trinity_default
  - ajet_default
  - _self_
```