# Workflow Development Guide
In Trinity-RFT, workflows are the core components that define the interaction between Agents and Environments. A complete workflow uses a model to accomplish the specified task and obtains feedback (a reward) from the environment. Below are the steps to create a new workflow.
## Step 0: Basic Concepts
Before starting development, it’s important to understand several core concepts:
```mermaid
flowchart LR
    A([Task]) & B([Model]) --> C[Workflow]
    C --> D([Experience])
```
- **Task** (`trinity.common.workflows.Task`): A data structure that contains all the information needed for a single run of the workflow. Tasks are commonly provided by the training dataset; each sample in the dataset is converted into a `Task` instance. The content of a `Task` varies with the task type:
  - Math problems: a `Task` contains the problem description and the golden answer.
  - Programming scenarios: a `Task` includes the problem description, test cases, the runtime environment, and other complex information.
- **Model** (`trinity.common.models.model.ModelWrapper`): The model being trained. The workflow uses this model to generate responses based on the task. Trinity-RFT provides the model instance when initializing the workflow.
- **Workflow** (`trinity.common.workflows.Workflow`): Defines the interaction flow between Agents and Environments. A `Workflow` initializes itself from a `Task` and uses the `Model` to generate responses. Unlike general agent applications, a `Workflow` also needs to calculate rewards based on the environment's feedback. Trinity-RFT provides several built-in workflows, including:
  - `MathWorkflow` (`trinity.common.workflows.MathWorkflow`): For math scenarios; submits problems to the LLM, parses LLM responses, and calculates scores (rewards).
  - `WebShopWorkflow` (`trinity.common.workflows.WebShopWorkflow`): For WebShop scenarios; involves multi-turn interaction with the environment.
  - `AgentScopeReActWorkflow` (`trinity.common.workflows.AgentScopeReActWorkflow`): Directly uses a pre-implemented ReAct agent (based on AgentScope) to solve tasks.
- **Experience** (`trinity.common.experience.Experience`): The output of running a `Workflow`. The number and structure of `Experience` objects depend on the specific workflow. For example, for common PPO/GRPO algorithms, an `Experience` includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.
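To make these concepts concrete, here is a minimal sketch of packing one model response into an `Experience` for a math task. The field values are made-up placeholders, and the actual `Experience` class may carry additional fields (e.g., action masks):

```python
from trinity.common.experience import Experience

# One dataset sample; Trinity-RFT converts it into a Task whose
# raw_task is this dict:
raw_task = {"question": "1+1=", "answer": "2"}

# After the workflow calls the model, it packs the result like this
# (token IDs and logprobs below are illustrative placeholder values):
experience = Experience(
    tokens=[101, 102, 103, 104],  # prompt + response token IDs
    prompt_length=2,              # the first 2 tokens belong to the prompt
    reward=1.0,                   # reward computed from environment feedback
    logprobs=[-0.31, -0.05],      # log probabilities of the generated tokens
)
```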
## Step 1: Prepare Task Dataset
The task dataset is loaded via the `buffer.explorer_input.taskset` entry in your YAML config file.

To handle differences in `Task` contents, Trinity-RFT provides a unified `Task` interface containing the following fields:
- `workflow` (`str`): The registered name of your workflow class. You can specify it in `buffer.explorer_input.taskset.default_workflow_type` of your YAML config file.
- `reward_fn` (`Optional[str]`): The registered name of your reward function. You can specify it in `buffer.explorer_input.taskset.default_reward_fn_type`. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.
- `raw_task` (`Dict`): A record of raw data in `Dict` format. For a highly customized workflow, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields.
- `format_args` (`trinity.common.config.FormatConfig`): Parameters that facilitate the construction of `Workflow` instances. For example, `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.taskset.format`.
- `rollout_args` (`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.taskset.rollout_args`.
- `workflow_args` (`Dict`): A dictionary of parameters that facilitate the construction of `Workflow` instances, offering more flexibility than `format_args` and `rollout_args`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.taskset.workflow_args`. Normally, you do not need to set this field.
Tip

`workflow`, `workflow_args`, and `raw_task` provide different levels of customization:

- `workflow` provides the global settings for all tasks that use the same workflow. (Global Level)
- `workflow_args` can be set per task dataset, allowing different task datasets that use the same workflow to behave differently; see the sketch after this list. (Dataset Level)
- `raw_task` customizes the behavior of each individual task, which is the most flexible. (Data Sample Level)
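For instance, dataset-level customization via `workflow_args` might look like the following sketch. The keys under `workflow_args` (`max_turns`, `use_hint`) are hypothetical and would be interpreted by your own workflow's `__init__`:

```yaml
buffer:
  explorer_input:
    taskset:
      default_workflow_type: "example_workflow"
      workflow_args:
        max_turns: 3     # hypothetical key, read by your workflow
        use_hint: false  # hypothetical key
```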
In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line is a JSON object with `question` and `answer` fields representing the problem description and the standard answer, respectively. For example:

```json
{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...
```
Example configuration snippet:
```yaml
# some config
buffer:
  explorer_input:
    taskset:
      default_workflow_type: "math_workflow"
      path: ${oc.env:TRINITY_TASKSET_PATH}
      format:
        prompt_key: "question"
        response_key: "answer"
      rollout_args:
        temperature: 1.0
# some other configs
```
In this example, each task object's `raw_task` is a `Dict` with two keys (`question` and `answer`). The `MathWorkflow` uses `prompt_key` and `response_key` to extract the question and answer from `raw_task`, and uses `rollout_args` to generate the response.
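As a sketch, a workflow could use the `format_args` keys to stay dataset-agnostic instead of hard-coding `"question"` and `"answer"` (a fragment, assuming the field names described above):

```python
# Hedged sketch: extracting prompt and answer via format_args
question = task.raw_task[task.format_args.prompt_key]    # -> "1+1="
answer = task.raw_task[task.format_args.response_key]    # -> "2"
```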
## Step 2: Implement a New Workflow
The `Workflow` base class interface is as follows:
```python
class Workflow(ABC):
    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        self.task = task
        self.model = model
        self.auxiliary_models = auxiliary_models

    @abstractmethod
    def run(self) -> List[Experience]:
        """Run the workflow and return a list of Experiences."""
```
### Initialize Your Workflow
During initialization, `Workflow` receives the following parameters:

- `task` (`trinity.common.workflows.Task`): A single data item from the task dataset.
- `model` (`trinity.common.models.model.ModelWrapper`): The model being trained. It provides an interface similar to OpenAI's, receiving a list of conversation messages and returning content generated by the LLM, including the reply text (`response_text`), the full-sequence token IDs (`tokens`), the prompt token length (`prompt_length`), and a list of output token log probabilities (`logprobs`).
- `auxiliary_models` (`List[openai.OpenAI]`): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.
Tip

You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.

The `model` field for OpenAI API calls can be obtained via `openai_client.models.list().data[0].id` or `openai_client.model_path`; see the sketch below.
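For example, a minimal call through the OpenAI-compatible client might look like the following sketch. The method name `solve` and the message content are hypothetical; only the client calls follow the tip above:

```python
def solve(self) -> str:
    """Query the trained model through its OpenAI-compatible API (sketch)."""
    client = self.model.get_openai_client()
    model_id = client.models.list().data[0].id  # or client.model_path
    completion = client.chat.completions.create(
        model=model_id,
        messages=[{"role": "user", "content": "1+1="}],
    )
    return completion.choices[0].message.content
```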
Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use `format_args` for further customization.
```python
class ExampleWorkflow(Workflow):
    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args
        # Optional: if you want to use the OpenAI API in your workflow
        # self.openai_client = self.model.get_openai_client()
```
### Implementing the `run` Method

The `run` method is the core of your workflow. It returns a list of `Experience` objects.

Below is a simple implementation for a math workflow. We first call the model to generate a response using the provided question and rollout arguments. Then we calculate the reward for the response using the `calculate_reward` function. Finally, we construct and return a list of `Experience` objects containing the response and reward.
```python
class ExampleWorkflow(Workflow):
    # the __init__ function
    # ...

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate a response
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            temperature=self.rollout_args.temperature,
        )
        response = responses[0]  # there is only one response
        reward: float = self.calculate_reward(response.response_text, self.answer)
        return [
            Experience(
                tokens=response.tokens,
                prompt_length=response.prompt_length,
                reward=reward,
                logprobs=response.logprobs,
            )
        ]
```
### Registering Your Workflow
Register your workflow using the `WORKFLOWS.register_module` decorator. Ensure the name does not conflict with existing workflows.
```python
# import some packages
from trinity.common.workflows.workflow import WORKFLOWS


@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    pass
```
For workflows intended to be contributed to the Trinity-RFT project, place the above code in the `trinity/common/workflows` folder, e.g., `trinity/common/workflows/example_workflow.py`, and add the following lines to `trinity/common/workflows/__init__.py`:
```python
# existing import lines
from trinity.common.workflows.example_workflow import ExampleWorkflow

__all__ = [
    # existing __all__ lines
    "ExampleWorkflow",
]
```
### Performance Optimization
#### Avoid Re-initialization

For heavy workflows, re-initializing the workflow for every task incurs extra computational cost. In this case, you can set the `can_reset` attribute and implement the `reset` method to avoid re-initialization. `can_reset` is a class attribute that indicates whether the workflow supports resetting; the `reset` method accepts a `Task` parameter and resets the workflow's internal state based on the new task.
```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    can_reset: bool = True

    # some code
    # ...

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
```
#### Support Batch Inference

Many popular RL algorithms (e.g., GRPO) require multiple runs of the same task. In such scenarios, you can use batch inference to obtain multiple responses for a single question and improve efficiency.

To support this, implement the `can_repeat` attribute and the `set_repeat_times` method. `can_repeat` is a class attribute that indicates whether the workflow supports multiple executions within the `run` method. The `set_repeat_times` method accepts two parameters: `repeat_times` specifies the number of executions within the `run` method, and `run_id_base` is an integer identifying the first run ID among the multiple runs (this parameter is used in multi-turn interaction scenarios; for tasks completed with a single model call, it can be ignored).
```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    can_repeat: bool = True

    # some code

    def set_repeat_times(self, repeat_times, run_id_base):
        self.repeat_times = repeat_times
        self.run_id_base = run_id_base

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.repeat_times,  # run multiple times in one call
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences
```
### Full Code Example
```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    can_reset: bool = True
    can_repeat: bool = True

    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.repeat_times,  # set via set_repeat_times
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

    def set_repeat_times(self, repeat_times, run_id_base):
        self.repeat_times = repeat_times
        self.run_id_base = run_id_base
```
## Step 3: Use Your Workflow

After implementing and registering your workflow, update the configuration file by setting `default_workflow_type` in the `buffer.explorer_input.taskset` section to the newly registered workflow name.
```yaml
buffer:
  # Other fields
  explorer_input:
    taskset:
      path: /path/to/taskset
      default_workflow_type: example_workflow
      # Other fields
```
Now you can run your workflow in Trinity-RFT using the command:

```bash
trinity run --config <your_yaml_file>
```
## Advanced Features

### Async Support
The examples above target synchronous mode. If your workflow needs to use asynchronous methods (e.g., an asynchronous API), set `is_async` to `True` and implement the `run_async` method. In this case, you no longer need to implement the `run` method, and the initialization parameter `auxiliary_models` becomes `List[openai.AsyncOpenAI]`, while other methods and properties remain unchanged.
```python
@WORKFLOWS.register_module("example_workflow_async")
class ExampleWorkflowAsync(Workflow):
    is_async: bool = True

    async def run_async(self) -> List[Experience]:
        # your async code here
        # no need to implement the run() method
        ...
```
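As a fuller sketch, an async workflow can combine the async OpenAI client with the history mechanism described in the next section. This assumes `explorer.rollout_model.enable_openai_api` and `enable_history` are both set to `true`; the class name and reward logic are placeholders:

```python
@WORKFLOWS.register_module("my_async_workflow")
class MyAsyncWorkflow(Workflow):
    is_async: bool = True

    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.client: openai.AsyncOpenAI = self.model.get_openai_async_client()
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

    async def run_async(self) -> List[Experience]:
        # the call is recorded by the framework when enable_history is true
        await self.client.chat.completions.create(
            model=self.client.model_path,
            messages=[{"role": "user", "content": f"Question:\n{self.question}"}],
        )
        # experiences are reconstructed from the recorded call history
        experiences = self.model.extract_experience_from_history()
        for exp in experiences:
            exp.reward = 1.0  # placeholder: replace with your reward logic
        return experiences
```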
### Using OpenAI API
Trinity-RFT provides an option to use the OpenAI API for model inference. You can enable this feature by setting `explorer.rollout_model.enable_openai_api` to `true` in your configuration file. This allows you to obtain an `openai.OpenAI` instance via the `get_openai_client` method of the model instance provided by Trinity-RFT.

Additionally, since the OpenAI API does not provide all the data required for training, you also need to set `explorer.rollout_model.enable_history` to `true`. This lets the framework automatically record data that can be used for training and convert it into a list of `Experience` objects, which you can extract using the `extract_experience_from_history` method.
```yaml
# example config snippet
explorer:
  rollout_model:
    enable_openai_api: true
    enable_history: true
# Other fields
```
```python
import openai


@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.model = model
        self.client: openai.OpenAI = self.model.get_openai_client()
        # or async client
        # self.client: openai.AsyncOpenAI = self.model.get_openai_async_client()
        self.agent = MyAgent(openai_client=self.client)  # MyAgent is your own agent

    def calculate_reward(self, response: str) -> float:
        # your reward calculation logic
        ...

    def run(self) -> List[Experience]:
        # run your agent
        response = self.agent.run()
        # calculate reward
        reward = self.calculate_reward(response)
        # extract experiences from the history recorded in self.model
        experiences = self.model.extract_experience_from_history()
        for exp in experiences:
            exp.reward = reward
        return experiences
```
Tip

- Currently, only calls to `openai.OpenAI.chat.completions.create` and `openai.AsyncOpenAI.chat.completions.create` are automatically recorded and converted into `Experience` objects. Streaming output is not supported.
- When calling `chat.completions.create`, the `model` field can be obtained via `openai_client.models.list().data[0].id` or `openai_client.model_path`.
- For more complex workflow examples using the OpenAI API, refer to ReAct Agent Training.
### LLM-as-a-judge Support
LLM-as-a-judge is a common reward calculation method, especially suitable for open-ended tasks (such as programming and writing). In these scenarios, the workflow leverages an additional LLM to evaluate answer quality and compute the reward signal.

To support this, Trinity-RFT provides an auxiliary models mechanism. Auxiliary models are a set of models not involved in training; the workflow can use them to assist with tasks, such as acting as a judge to calculate rewards.

You can specify one or more auxiliary models in the configuration file via the `explorer.auxiliary_models` field. For example:
```yaml
explorer:
  auxiliary_models:
    - model_path: Qwen/Qwen2.5-32B-Instruct
      engine_num: 1
      tensor_parallel_size: 2
      enable_thinking: false
      max_prompt_tokens: 12288
      max_response_tokens: 12288
      max_model_len: 16384
    - model_path: Qwen/Qwen3-8B
      engine_num: 1
      tensor_parallel_size: 2
      enable_thinking: false
      max_prompt_tokens: 12288
      max_response_tokens: 12288
      max_model_len: 16384
```
Note that each auxiliary model independently occupies `tensor_parallel_size * engine_num` GPUs, so configure according to your hardware resources. After enabling auxiliary models, the number of GPUs available to the Trainer is the total GPU count minus those occupied by all auxiliary models and by the inference model being trained (`rollout_model`).
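For example, with the snippet above, the two auxiliary models together occupy `2 * 1 + 2 * 1 = 4` GPUs. On an 8-GPU node where the rollout model occupies, say, 2 GPUs (an illustrative figure), the Trainer would be left with `8 - 4 - 2 = 2` GPUs.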
The auxiliary models specified in the configuration file automatically enable the OpenAI API, and the corresponding `openai.OpenAI` or `openai.AsyncOpenAI` instances (depending on the `is_async` setting) are passed to the `auxiliary_models` parameter of the `Workflow` initialization method. For example:
```python
class MyWorkflow(Workflow):
    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.judge_model = self.auxiliary_models[0]  # use the first auxiliary model as the judge

    def run(self) -> List[Experience]:
        response = self.do_something()
        reward_response = self.judge_model.chat.completions.create(
            model=self.judge_model.model_path,
            messages=[
                {
                    "role": "system",
                    "content": "You are a judge. You need to give a score from 0 to 1 based on the quality of the answer.",
                },
                {
                    "role": "user",
                    "content": f"Question:\n{self.task.raw_task['question']}\nAnswer:\n{response.response_text}\nPlease give a score from 0 to 1.",
                },
            ],
            temperature=0.0,
            max_tokens=10,
        )
        # parse the reward score
        reward = float(reward_response.choices[0].message.content.strip())
        return [
            Experience(
                tokens=response.tokens,
                prompt_length=response.prompt_length,
                reward=reward,
                logprobs=response.logprobs,
            )
        ]
```
### Debug Mode
During workflow development, repeatedly launching the full training process for testing is time-consuming and inefficient. To address this, Trinity-RFT provides a Debug Mode. This mode leverages a pre-launched inference model to quickly run specified workflows and obtain results, avoiding repeated model loading and initialization delays and significantly improving development efficiency. The process is illustrated below:

```mermaid
flowchart LR
    A[Start Inference Model] --> B[Debug Workflow]
    B --> B
```
To start the inference model, use the following command:

```bash
trinity debug --config <config_file_path> --module inference_model
```
Here, `<config_file_path>` is the path to a YAML configuration file in the same format as the one used by the `trinity run` command. The `explorer.rollout_model` and `explorer.auxiliary_models` fields in the config are loaded to initialize the inference model.
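A minimal debug config might therefore contain only those fields plus the taskset described below. The keys here mirror the configuration examples earlier in this guide and should be treated as a sketch, not a complete config:

```yaml
explorer:
  rollout_model:
    model_path: Qwen/Qwen2.5-7B-Instruct  # placeholder model
    enable_openai_api: true
buffer:
  explorer_input:
    taskset:
      path: /path/to/taskset
      default_workflow_type: example_workflow
```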
Once started, the model keeps running and waits for debug instructions; it will not exit automatically. You can then run the following command in another terminal to debug your workflow:

```bash
trinity debug --config <config_file_path> --module workflow --output_file <output_file_path> --plugin_dir <plugin_dir>
```
- `<config_file_path>`: Path to the YAML configuration file, usually the same one used to start the inference model.
- `<output_file_path>`: Path to save the performance profiling results. Debug Mode uses viztracer to profile the workflow execution and saves the results as an HTML file for easy viewing in a browser.
- `<plugin_dir>` (optional): Path to the plugin directory. If your workflow or reward function modules are not built into Trinity-RFT, specify this parameter to load custom modules; see the example invocation after this list.
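For example, a typical invocation might look like this (all paths are placeholders):

```bash
trinity debug --config my_config.yaml --module workflow \
    --output_file workflow_profile.html --plugin_dir ./my_plugins
```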
During debugging, the `buffer.explorer_input.taskset` field in the config is loaded to initialize the task dataset and the workflow instance. Note that Debug Mode reads only the first sample in the dataset for testing. After running the above command, the workflow's return value is automatically formatted and printed in the terminal for easy inspection.
When debugging is complete, terminate the inference model by pressing `Ctrl+C` in its terminal.