Workflow Development Guide#
In Trinity-RFT, workflows are the core components that define the interaction between Agents and Environments. A workflow uses the model being trained to complete the specified task and obtains feedback (a reward) from the environment. The steps to create a new workflow are as follows:
Step 0: Basic Concepts#
Before starting development, it’s important to understand several core concepts:
Task (trinity.common.workflows.Task): Represents a data structure that can be converted into a Workflow. The content of the Task varies depending on the task type:
    Math problems: A Task contains the problem description and the golden answer.
    Programming scenarios: A Task includes the problem description, test cases, runtime environment, and other complex information.
Workflow (trinity.common.workflows.Workflow): Describes how a Task is executed. It defines the interaction flow between Agents and Environments, including logic similar to Rollout and Reward calculations in other frameworks. After execution, it generates a list of Experience. Trinity-RFT includes several built-in workflows:
    MathWorkflow (trinity.common.workflows.MathWorkflow): For math scenarios; submits problems to the LLM, parses the LLM responses, and calculates scores (rewards).
    WebShopWorkflow (trinity.common.workflows.WebShopWorkflow): For webshop scenarios; contains multi-turn interaction with the environment.
    AgentScopeReactMathWorkflow (trinity.common.workflows.AgentScopeReactMathWorkflow): For math scenarios; directly uses a pre-implemented ReActAgent (based on AgentScope) to solve math problems.
Experience (trinity.common.experience.Experience): The output of running a Workflow. The internal data format depends on the training algorithm used. For example, for common PPO/GRPO algorithms, Experience includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.
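To make the relationship between token IDs, prompt length, and the action mask concrete, here is a minimal sketch for the single-turn case. `ToyExperience` is a hypothetical stand-in written for this illustration; it is not the real `trinity.common.experience.Experience` class, whose fields and layout may differ.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyExperience:
    """Hypothetical stand-in for illustration only (not Trinity-RFT's Experience)."""

    tokens: List[int]      # prompt tokens followed by generated tokens
    prompt_length: int     # number of leading tokens that belong to the prompt
    logprobs: List[float]  # one log probability per generated token
    reward: float

    def action_mask(self) -> List[bool]:
        # True marks tokens produced by the LLM; prompt tokens are masked out.
        return [i >= self.prompt_length for i in range(len(self.tokens))]


exp = ToyExperience(
    tokens=[11, 12, 13, 21, 22],
    prompt_length=3,
    logprobs=[-0.1, -0.4],
    reward=1.0,
)
mask = exp.action_mask()
```

In this sketch only the last two tokens count as actions, which is why there are exactly two log probabilities.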
Step 1: Prepare Task Dataset#
The task dataset is loaded via the buffer.explorer_input.taskset configuration entry in your YAML config file.
To handle differences in Task contents, Trinity-RFT provides a unified Task interface containing the following fields.
workflow (str): The registered name of your workflow class. You can specify it in buffer.explorer_input.taskset.default_workflow_type of your YAML config file.
reward_fn (Optional[str]): The registered name of your reward function. You can specify it in buffer.explorer_input.taskset.default_reward_fn_type. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.
raw_task (Dict): A record of raw data in Dict format. For a highly customized workflow, you can directly use raw_task to initialize your Workflow instance without relying on the following fields.
format_args (trinity.common.config.FormatConfig): Parameters to facilitate the construction of Workflow instances. For example, prompt_key and response_key can be used to get the prompt and response from raw_task. These settings come from the YAML configuration file and can be set in buffer.explorer_input.taskset.format.
rollout_args (trinity.common.config.GenerationConfig): Parameters that control the rollout process, such as temperature. This field also comes from the YAML configuration file and can be set in buffer.explorer_input.taskset.rollout_args.
workflow_args (Dict): A dictionary of parameters to facilitate the construction of Workflow instances. It provides more flexibility than format_args and rollout_args by using a dictionary. This field also comes from the YAML configuration file and can be set in buffer.explorer_input.taskset.workflow_args. Normally, you do not need to set this field.
Tip
workflow, workflow_args and raw_task provide different levels of customization.
    workflow provides the global settings for all tasks that use the same workflow. (Global Level)
    workflow_args can be set for each task dataset, allowing different task datasets using the same workflow to behave differently. (Dataset Level)
    raw_task provides the ability to customize the behavior of each task, which is the most flexible. (Data Sample Level)
In the math problem scenario, the Task dataset can be a jsonl file, where each line contains JSON with question and answer fields representing the problem description and standard answer, respectively. For example:
{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...
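A task file of this shape can be produced and sanity-checked with the Python standard library alone. The sketch below writes the two records shown above and reads them back; the file name `math_tasks.jsonl` and the temporary directory are assumptions made for the example.

```python
import json
import os
import tempfile

records = [
    {"question": "1+1=", "answer": "2"},
    {"question": "2+2=", "answer": "4"},
]

# Hypothetical location; in practice this is the path your taskset config points to.
path = os.path.join(tempfile.mkdtemp(), "math_tasks.jsonl")

# Write one JSON object per line.
with open(path, "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

# Read the file back, skipping blank lines.
with open(path, encoding="utf-8") as f:
    loaded = [json.loads(line) for line in f if line.strip()]
```

Reading the file back before training is a cheap way to catch malformed lines or missing keys.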
Example configuration snippet:
# some config
buffer:
  explorer_input:
    taskset:
      default_workflow_type: "math_workflow"
      path: ${oc.env:TRINITY_TASKSET_PATH}
      format:
        prompt_key: "question"
        response_key: "answer"
      rollout_args:
        temperature: 1.0
# some other configs
In this example, each task object’s raw_task is a Dict with two keys (question and answer). The MathWorkflow uses the prompt_key and response_key to extract the question and answer from the raw_task and uses the rollout_args to generate the response.
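The key-based lookup described here can be sketched as follows. `extract_prompt_and_truth` is a hypothetical helper written for this illustration; the actual MathWorkflow implementation may extract these values differently.

```python
from typing import Dict, Tuple


def extract_prompt_and_truth(
    raw_task: Dict[str, str], prompt_key: str, response_key: str
) -> Tuple[str, str]:
    """Hypothetical helper: look up the question and answer by the configured keys."""
    return raw_task[prompt_key], raw_task[response_key]


raw_task = {"question": "1+1=", "answer": "2"}
question, truth = extract_prompt_and_truth(
    raw_task, prompt_key="question", response_key="answer"
)
```

Because the keys come from configuration, the same workflow can read datasets whose columns are named differently.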
Step 2: Implement a New Workflow#
The Workflow base class interface is as follows:
class Workflow(ABC):
    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        self.task = task
        self.model = model
        self.auxiliary_models = auxiliary_models

    @abstractmethod
    def run(self) -> List[Experience]:
        """Run the workflow and return a list of Experiences."""
Initialize Your Workflow#
During initialization, Workflow receives the following parameters:
task (trinity.common.workflows.Task): A single data item from the task dataset.
model (trinity.common.models.model.ModelWrapper): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text response_text, full sequence token ids tokens, prompt part token length prompt_length, and a list of output token logprobs logprobs).
auxiliary_models (List[openai.OpenAI]): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.
Tip
You can switch to using the OpenAI API by setting explorer.rollout_model.enable_openai_api to true in your config file and calling model.get_openai_client() to get an openai.OpenAI instance in your workflow.
The model field required when calling the OpenAI API can be obtained via openai_client.models.list().data[0].id or openai_client.model_path.
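As a rough sketch of how such a client might be used inside a workflow, the function below looks up the model id as described in the tip and sends a single chat request. `ask_via_openai_api` is a hypothetical helper name, and the client is assumed to be the object returned by `model.get_openai_client()`; nothing here is specific to Trinity-RFT beyond that assumption.

```python
from typing import Any


def ask_via_openai_api(openai_client: Any, question: str) -> str:
    """Hypothetical helper: send one user message and return the reply text."""
    # Model id lookup, as described in the tip above.
    model_id = openai_client.models.list().data[0].id
    completion = openai_client.chat.completions.create(
        model=model_id,
        messages=[{"role": "user", "content": f"Question:\n{question}"}],
    )
    # The first choice carries the generated reply.
    return completion.choices[0].message.content
```

Note that a reply obtained this way is plain text; whether token ids and logprobs are also recorded for training depends on how the rollout model is configured.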
Here’s an example of initializing a simple workflow using only raw_task and rollout_args. In more complex cases, you can use the format_args for further customization.
class ExampleWorkflow(Workflow):
    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args
        # Optional: If you want to use OpenAI API in your workflow
        # self.openai_client = self.model.get_openai_client()
Implementing the run method#
The run method is the core of your workflow. It returns a list of Experience.
Below is a simple implementation for a math workflow.
We first call the model to generate multiple responses using the provided question and rollout arguments.
Then we calculate the reward for each response using the calculate_reward function.
Finally, we construct a list of Experience with the responses and rewards and return it.
class ExampleWorkflow(Workflow):
    # the __init__ function

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences
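The exact string comparison in calculate_reward gives a reward of 0.0 for a correct answer that merely carries a trailing newline or extra spaces. One possible refinement, shown here as a standalone sketch rather than as part of Trinity-RFT, is to normalize whitespace before comparing. `normalize_answer` is a hypothetical helper introduced for this example.

```python
import re


def normalize_answer(text: str) -> str:
    """Strip surrounding whitespace and collapse internal whitespace runs."""
    return re.sub(r"\s+", " ", text.strip())


def calculate_reward(response: str, truth: str) -> float:
    """Return 1.0 when the normalized response equals the normalized truth."""
    return 1.0 if normalize_answer(response) == normalize_answer(truth) else 0.0
```

This still requires the model to output only the answer; extracting an answer from a longer response would need additional parsing.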
Registering Your Workflow#
Register your workflow using the WORKFLOWS.register_module decorator.
Ensure the name does not conflict with existing workflows.
# import some packages
from trinity.common.workflows.workflow import WORKFLOWS


@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    pass
For workflows intended to be contributed to the Trinity-RFT project, place the above code in the trinity/common/workflows folder, e.g., trinity/common/workflows/example_workflow.py, and add the following lines to trinity/common/workflows/__init__.py:
# existing import lines
from trinity.common.workflows.example_workflow import ExampleWorkflow

__all__ = [
    # existing __all__ lines
    "ExampleWorkflow",
]
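To illustrate what a register_module-style decorator does conceptually, and why a name must not collide with an existing one, here is a minimal registry sketch. `ToyRegistry`, `TOY_WORKFLOWS`, and `ToyWorkflow` are hypothetical names for this example; Trinity-RFT's WORKFLOWS registry may be implemented differently and may handle duplicates in another way.

```python
from typing import Callable, Dict, Type


class ToyRegistry:
    """Hypothetical minimal registry mapping names to classes."""

    def __init__(self) -> None:
        self._modules: Dict[str, Type] = {}

    def register_module(self, name: str) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            # Refuse duplicate names so one class cannot silently replace another.
            if name in self._modules:
                raise ValueError(f"{name!r} is already registered")
            self._modules[name] = cls
            return cls

        return decorator

    def get(self, name: str) -> Type:
        return self._modules[name]


TOY_WORKFLOWS = ToyRegistry()


@TOY_WORKFLOWS.register_module("example_workflow")
class ToyWorkflow:
    pass
```

Looking up a class by its registered string name is what allows a YAML config to select a workflow without importing it explicitly.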
Avoid Re-initialization#
For heavy workflows, re-initializing for every task can incur extra computational cost.
In this case, you can implement the resettable and reset methods to avoid re-initialization.
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    # some code
    # ...

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
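The benefit of resettable and reset is easiest to see from the caller's side. The sketch below shows how a runner could reuse one workflow instance across tasks; `ToyTask`, `ToyWorkflow`, and `run_all` are hypothetical stand-ins, and Trinity-RFT's actual runner logic is not shown in this guide and may differ.

```python
from typing import Dict, List, Optional


class ToyTask:
    """Hypothetical stand-in for Task, carrying only raw_task."""

    def __init__(self, raw_task: Dict[str, str]) -> None:
        self.raw_task = raw_task


class ToyWorkflow:
    """Hypothetical workflow whose constructor is assumed to be expensive."""

    init_count = 0  # counts how many times the constructor ran

    def __init__(self, task: ToyTask) -> None:
        ToyWorkflow.init_count += 1
        self.reset(task)

    def resettable(self) -> bool:
        return True

    def reset(self, task: ToyTask) -> None:
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

    def run(self) -> str:
        return f"{self.question} -> {self.answer}"


def run_all(tasks: List[ToyTask]) -> List[str]:
    workflow: Optional[ToyWorkflow] = None
    results = []
    for task in tasks:
        if workflow is not None and workflow.resettable():
            workflow.reset(task)  # reuse the existing instance
        else:
            workflow = ToyWorkflow(task)  # pay the construction cost once
        results.append(workflow.run())
    return results


outputs = run_all(
    [
        ToyTask({"question": "1+1=", "answer": "2"}),
        ToyTask({"question": "2+2=", "answer": "4"}),
    ]
)
```

Here the constructor runs once for two tasks, so reset must update every piece of per-task state; anything it forgets would leak from the previous task.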
Full Code Example#
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
Step 3: Use Your Workflow#
After implementing and registering your workflow, update the configuration file so that default_workflow_type in the buffer.explorer_input.taskset section is set to the newly registered Workflow name.
buffer:
  # Other fields
  explorer_input:
    taskset:
      path: /path/to/taskset
      default_workflow_type: example_workflow
      # Other fields
Now you can run your workflow in Trinity-RFT using the command:
trinity run --config <your_yaml_file>