Developer Guide
This guide introduces how to add new workflows to Trinity-RFT and provides relevant development guidelines.
Note
Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
Creating New Workflows
Trinity-RFT allows developers to register new workflows (e.g., for multi-turn interactions or agentic scenarios). Below are the steps to create a new workflow:
Step 0: Basic Concepts
Before starting development, it’s important to understand several core concepts:
- Task (trinity.common.workflows.Task): A data structure that can be converted into a Workflow. The content of a Task varies depending on the task type:
  - Math problems: A Task contains the problem description and the golden answer.
  - Programming scenarios: A Task includes the problem description, test cases, runtime environment, and other complex information.
- Workflow (trinity.common.workflows.Workflow): Can be understood as the running state of a Task. It defines the interaction flow between Agents and Environments, including logic similar to Rollout and Reward calculation in other frameworks. After execution, it generates a list of Experience objects. Trinity-RFT includes several built-in workflows:
  - MathWorkflow (trinity.common.workflows.MathWorkflow): For math scenarios; submits problems to the LLM, parses the LLM responses, and calculates scores (rewards).
  - WebShopWorkflow (trinity.common.workflows.WebShopWorkflow): For webshop scenarios; involves multi-turn interaction with the environment.
  - CodeWorkflow (coming soon): For coding scenarios; executes the returned code, runs tests, and calculates rewards based on the test results.
  - …
- Experience (trinity.common.experience.Experience): The output of running a Workflow. The internal data format depends on the training algorithm used. For example, for common PPO/GRPO algorithms, an Experience includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.
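To make the action mask concrete, consider the single-turn case: given the full token sequence and the prompt length, the mask marks the positions that the LLM generated. The helper below is a hypothetical illustration written for this guide, not part of Trinity-RFT, and the token IDs are made up; the actual Experience layout may differ.

```python
from typing import List


def build_action_mask(tokens: List[int], prompt_length: int) -> List[int]:
    """Return 1 for generated (response) tokens and 0 for prompt tokens.

    Hypothetical helper for illustration only; not part of Trinity-RFT.
    """
    if not 0 <= prompt_length <= len(tokens):
        raise ValueError("prompt_length must lie within the token sequence")
    return [0] * prompt_length + [1] * (len(tokens) - prompt_length)


# Made-up token IDs: the first three belong to the prompt.
mask = build_action_mask([101, 7592, 102, 2023, 2003], prompt_length=3)
print(mask)
```

In multi-turn workflows the generated tokens are usually interleaved with environment tokens, so a single prompt length is no longer enough to derive the mask.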
Step 1: Prepare Task Dataset
The task dataset is loaded via the buffer.explorer_input.taskset configuration entry in your YAML config file.
To handle differences in Task contents, Trinity-RFT provides a unified Task interface containing the following fields.
- workflow (str): The registered name of your workflow class. You can specify it in buffer.explorer_input.taskset.default_workflow_type of your YAML config file.
- reward_fn (Optional[str]): The registered name of your reward function. You can specify it in buffer.explorer_input.taskset.default_reward_fn_type. Some workflows already include built-in reward calculation; in such cases, you can omit this field.
- raw_task (Dict): A record of raw data in Dict format. For a highly customized workflow, you can use raw_task directly to initialize your Workflow instance without relying on the following fields.
- format_args (trinity.common.config.FormatConfig): Parameters that facilitate the construction of Workflow instances. For example, prompt_key and response_key can be used to get the prompt and response from raw_task. These settings come from the YAML configuration file and can be set in buffer.explorer_input.taskset.format.
- rollout_args (trinity.common.config.GenerationConfig): Parameters that control the rollout process, such as temperature. This field also comes from the YAML configuration file and can be set in buffer.explorer_input.taskset.rollout_args.
In the math problem scenario, the Task dataset can be a jsonl file, where each line contains a JSON object with question and answer fields representing the problem description and standard answer, respectively. For example:
{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...
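Before pointing the config at such a file, it can help to confirm that every line parses and carries both fields. The following is a minimal sketch using only the standard library; load_tasks and the file name train.jsonl are hypothetical and not part of Trinity-RFT.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List

REQUIRED_KEYS = ("question", "answer")  # matches the example fields above


def load_tasks(path: Path) -> List[Dict[str, str]]:
    """Read a JSONL file and fail early on malformed or incomplete records."""
    tasks = []
    with path.open(encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            missing = [k for k in REQUIRED_KEYS if k not in record]
            if missing:
                raise ValueError(f"line {line_no}: missing keys {missing}")
            tasks.append(record)
    return tasks


# Write a tiny dataset to a temporary directory, then read it back.
samples = [{"question": "1+1=", "answer": "2"}, {"question": "2+2=", "answer": "4"}]
with tempfile.TemporaryDirectory() as tmp_dir:
    data_file = Path(tmp_dir) / "train.jsonl"  # hypothetical file name
    data_file.write_text(
        "".join(json.dumps(s) + "\n" for s in samples), encoding="utf-8"
    )
    tasks = load_tasks(data_file)

print(len(tasks))
```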
Example configuration snippet:
# some config
buffer:
  explorer_input:
    taskset:
      default_workflow_type: "math_workflow"
      path: "/PATH/TO/FILE/DIR"
      format:
        prompt_key: "question"
        response_key: "answer"
      rollout_args:
        temperature: 1.0
  # some other configs
In this example, each task object’s raw_task is a Dict with two keys (question and answer). The MathWorkflow uses the prompt_key and response_key to extract the question and answer from the raw_task, and uses the rollout_args to generate the response.
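The key lookup described above can be sketched in a few lines. FormatArgs here is a simplified, hypothetical stand-in that models only the two keys mentioned in this guide; the real trinity.common.config.FormatConfig may contain more fields, and MathWorkflow may extract the values differently.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass
class FormatArgs:
    """Hypothetical stand-in for the format settings; only two keys are modelled."""

    prompt_key: str = "question"
    response_key: str = "answer"


def extract_prompt_and_truth(raw_task: Dict[str, Any], fmt: FormatArgs) -> Tuple[str, str]:
    """Look up the prompt and the reference answer using the configured keys."""
    return str(raw_task[fmt.prompt_key]), str(raw_task[fmt.response_key])


prompt, truth = extract_prompt_and_truth({"question": "1+1=", "answer": "2"}, FormatArgs())
print(prompt, truth)
```

Keeping the key names in configuration rather than in code means a dataset with differently named columns needs only a config change.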
Step 2: Implement a New Workflow
The Workflow base class interface is as follows:
class Workflow(ABC):

    def __init__(
        self,
        model: ModelWrapper,
        task: Task,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        self.model = model
        self.auxiliary_models = auxiliary_models

    @abstractmethod
    def run(self) -> List[Experience]:
        """Run the workflow and return a list of Experiences."""
Initializing Your Workflow
During initialization, Workflow receives the following parameters:
- model (trinity.common.models.model.ModelWrapper): The model being trained. It provides an interface similar to OpenAI, receiving a list of conversation messages and returning the content generated by the LLM, including the reply text (response_text), the token IDs of the full sequence (tokens), the token length of the prompt part (prompt_length), and a list of output token log probabilities (logprobs).
- task (trinity.common.workflows.Task): A single data item from the task dataset.
- auxiliary_models (List[openai.OpenAI]): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.
Tip
You can switch to using the OpenAI API by setting explorer.rollout_model.enable_openai_api to true in your config file and calling model.get_openai_client() to get an openai.OpenAI instance in your workflow.
Here’s an example of initializing a simple workflow using only raw_task and rollout_args. In more complex cases, you can use the format_args for further customization.
class ExampleWorkflow(Workflow):

    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
        super().__init__(model, task, auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args
        # Optional: If you want to use OpenAI API in your workflow
        # self.openai_client = self.model.get_openai_client()
Implementing the run method
The run method is the core of your workflow. It returns a list of Experience objects.
Below is a simple implementation for a math workflow.
We first call the model to generate multiple responses using the provided question and rollout arguments.
Then we calculate the reward for each response using the calculate_reward function.
Finally, we construct a list of Experience objects from the responses and rewards and return it.
class ExampleWorkflow(Workflow):

    # the __init__ function

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences
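Exact string equality is strict: a response such as " 2\n" would score 0.0 against the answer "2". If that is too strict for your data, one option is to normalize both strings before comparing them. The functions below sketch that idea as standalone code; the normalization rules are assumptions chosen for illustration and are not Trinity-RFT's built-in reward logic.

```python
def normalize_answer(text: str) -> str:
    """Collapse whitespace and ignore case; a deliberately simple normalization."""
    return " ".join(text.split()).lower()


def calculate_reward(response: str, truth: str) -> float:
    """Return 1.0 when the normalized strings match, else 0.0."""
    return 1.0 if normalize_answer(response) == normalize_answer(truth) else 0.0


print(calculate_reward(" 2\n", "2"))
print(calculate_reward("3", "2"))
```

Real math workflows typically also need to parse the final answer out of a longer response, which this sketch does not attempt.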
Registering Your Workflow
Register your workflow using the WORKFLOWS.register_module decorator.
Ensure the name does not conflict with existing workflows.
# import some packages
from trinity.common.workflows.workflow import WORKFLOWS


@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    pass
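The registry pattern behind such a decorator can be sketched in plain Python. The Registry class and DEMO_WORKFLOWS below are simplified stand-ins written for this guide, not Trinity-RFT's actual implementation; in particular, raising an error on a duplicate name is an assumption about how conflicts could be handled.

```python
from typing import Callable, Dict, Type


class Registry:
    """Simplified, hypothetical name-to-class registry."""

    def __init__(self) -> None:
        self._modules: Dict[str, Type] = {}

    def register_module(self, name: str) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            if name in self._modules:
                raise KeyError(f"'{name}' is already registered")
            self._modules[name] = cls
            return cls  # return the class unchanged so it stays usable

        return decorator

    def get(self, name: str) -> Type:
        return self._modules[name]


DEMO_WORKFLOWS = Registry()


@DEMO_WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow:
    pass


print(DEMO_WORKFLOWS.get("example_workflow") is ExampleWorkflow)
```

Because the config file refers to workflows by their registered name, a lookup like get is all a framework needs to go from a string in YAML to a class.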
Avoid Re-initialization
For heavy workflows, re-initializing for every task can incur extra computational cost.
In this case, you can implement the resettable and reset methods to avoid re-initialization.
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    # some code
    # ...

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
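How a caller might take advantage of these two methods can be sketched as follows. HeavyWorkflow and run_tasks are hypothetical stand-ins that use plain dictionaries instead of Task objects; the real Trinity-RFT runner may schedule and reuse workflows differently.

```python
from typing import Dict, List, Optional


class HeavyWorkflow:
    """Hypothetical workflow whose construction is expensive."""

    init_count = 0  # counts how often the expensive constructor ran

    def __init__(self, raw_task: Dict[str, str]) -> None:
        HeavyWorkflow.init_count += 1
        self.question = raw_task.get("question")

    def resettable(self) -> bool:
        return True

    def reset(self, raw_task: Dict[str, str]) -> None:
        # Only swap the per-task state; expensive resources are kept.
        self.question = raw_task.get("question")

    def run(self) -> str:
        return f"answered: {self.question}"


def run_tasks(raw_tasks: List[Dict[str, str]]) -> List[str]:
    """Reuse one workflow instance when it declares itself resettable."""
    workflow: Optional[HeavyWorkflow] = None
    results = []
    for raw_task in raw_tasks:
        if workflow is not None and workflow.resettable():
            workflow.reset(raw_task)
        else:
            workflow = HeavyWorkflow(raw_task)
        results.append(workflow.run())
    return results


outputs = run_tasks([{"question": "1+1="}, {"question": "2+2="}])
print(outputs, HeavyWorkflow.init_count)
```

The important detail is that reset must overwrite every piece of per-task state; anything it forgets leaks from the previous task into the next one.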
Full Code Example
# Import paths follow the module names given earlier in this guide.
from typing import List

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows import Task, Workflow
from trinity.common.workflows.workflow import WORKFLOWS


@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
        super().__init__(model, task, auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
Step 3: Use Your Workflow
After implementing and registering your workflow, update the configuration file: set default_workflow_type in the buffer.explorer_input.taskset section to the name of the newly registered workflow.
buffer:
  # Other fields
  explorer_input:
    taskset:
      path: /path/to/taskset
      default_workflow_type: example_workflow
      # Other fields
Now you can run your workflow in Trinity-RFT using the command:
trinity run --config <your_yaml_file>
Adding New Config Entries for the Config Generator (Advanced)
Step 0: Understanding Streamlit
Before adding new parameters to the Config Generator page, familiarize yourself with the relevant Streamlit APIs and mechanisms. This project mainly uses Streamlit input components and stores user-input parameters in st.session_state.
Step 1: Implement New Config Entries
To illustrate the process of creating a new parameter setting for the Config Generator page, we will use train_batch_size as an example.
1. Determine the appropriate scope for the parameter. Currently, parameters are categorized into four files:
   - trinity/manager/config_registry/buffer_config_manager.py
   - trinity/manager/config_registry/explorer_config_manager.py
   - trinity/manager/config_registry/model_config_manager.py
   - trinity/manager/config_registry/trainer_config_manager.py
   In this case, train_batch_size should be placed in the buffer_config_manager.py file.
2. Create a parameter setting function using Streamlit. The function name must start with ‘set_’, and the remainder of the name becomes the config name.
3. Decorate the parameter setting function with the CONFIG_GENERATORS.register_config decorator. This decorator requires the following information:
   - Default value of the parameter
   - Visibility condition (if applicable)
   - Additional config parameters (if needed)
Note
The CONFIG_GENERATORS.register_config decorator automatically passes key=config_name as an argument to the registered configuration function. Ensure that your function accepts this keyword argument.
For train_batch_size, we will use the following settings:
- Default value: 96
- Visibility condition: lambda: st.session_state["trainer_gpu_num"] > 0
- Additional config: {"_train_batch_size_per_gpu": 16}
Here’s the complete code for the train_batch_size parameter:
@CONFIG_GENERATORS.register_config(
    default_value=96,
    visible=lambda: st.session_state["trainer_gpu_num"] > 0,
    other_configs={"_train_batch_size_per_gpu": 16},
)
def set_train_batch_size(**kwargs):
    key = kwargs.get("key")
    trainer_gpu_num = st.session_state["trainer_gpu_num"]
    st.session_state[key] = (
        st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"]
    )

    def on_change():
        st.session_state["_train_batch_size_per_gpu"] = max(
            st.session_state[key] // st.session_state["trainer_gpu_num"], 1
        )

    st.number_input(
        "Train Batch Size",
        min_value=trainer_gpu_num,
        step=trainer_gpu_num,
        help=_str_for_train_batch_size(),
        on_change=on_change,
        **kwargs,
    )
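The arithmetic that links the two session-state values can be checked in isolation. The two helper functions below are hypothetical names that mirror the expressions in set_train_batch_size; they take plain integers instead of reading st.session_state, and the GPU counts are made-up inputs.

```python
def total_batch_size(per_gpu: int, gpu_num: int) -> int:
    """Value written to the widget: per-GPU batch size times GPU count."""
    return per_gpu * gpu_num


def per_gpu_batch_size(total: int, gpu_num: int) -> int:
    """Inverse used in on_change: floor division, clamped to at least 1."""
    return max(total // gpu_num, 1)


print(total_batch_size(16, 6))
print(per_gpu_batch_size(96, 6))
print(per_gpu_batch_size(100, 8))
print(per_gpu_batch_size(4, 8))
```

Because the widget's min_value and step both equal the GPU count, values entered through the stepper stay multiples of it; the clamp to 1 covers a total smaller than the GPU count.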
If the parameter requires validation, create a check function. For train_batch_size, we need to ensure it is divisible by trainer_gpu_num. If not, a warning should be displayed, and the parameter should be added to unfinished_fields.
Decorate the check function with the CONFIG_GENERATORS.register_check decorator:
@CONFIG_GENERATORS.register_check()
def check_train_batch_size(unfinished_fields: set, key: str):
    if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0:
        unfinished_fields.add(key)
        st.warning(_str_for_train_batch_size())
Note
The CONFIG_GENERATORS.register_check decorator automatically passes key=config_name and unfinished_fields=self.unfinished_fields as arguments to the registered check function. Ensure your function accepts these keyword arguments.
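The validation rule itself can be exercised without Streamlit. In the sketch below, check_divisible is a hypothetical function and a plain dict stands in for st.session_state; the warning call is left out so that only the divisibility logic remains.

```python
from typing import Dict, Set


def check_divisible(state: Dict[str, int], unfinished_fields: Set[str], key: str) -> None:
    """Flag the field when its value is not a multiple of the GPU count."""
    if state[key] % state["trainer_gpu_num"] != 0:
        unfinished_fields.add(key)


unfinished: Set[str] = set()

# 96 is a multiple of 8, so nothing is flagged here.
check_divisible({"train_batch_size": 96, "trainer_gpu_num": 8}, unfinished, "train_batch_size")
flagged_after_valid = set(unfinished)

# 100 is not a multiple of 8, so the field is flagged.
check_divisible({"train_batch_size": 100, "trainer_gpu_num": 8}, unfinished, "train_batch_size")
print(flagged_after_valid, unfinished)
```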
Step 2: Integrating New Parameters into config_manager.py
To integrate new parameters into the config_manager.py file, follow this procedure:
1. Parameter Categorization: Determine the appropriate section for the new parameter based on its functionality. The config generator page is structured into two primary modes:
   - Beginner Mode: Comprises “Essential Configs” and “Important Configs” sections.
   - Expert Mode: Includes “Model”, “Buffer”, “Explorer and Synchronizer”, and “Trainer” sections.
2. Parameter Addition: Incorporate the new parameter into the relevant section using the self.get_configs method within the ConfigManager class.
   Example:
   class ConfigManager:
       def _expert_buffer_part(self):
           self.get_configs("total_epochs", "train_batch_size")
3. YAML File Integration: Locate the appropriate position for the new parameter within the YAML file structure. This should be done in the generate_config function and its associated sub-functions.
4. Parameter Value Assignment: Use st.session_state to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML.
   Example:
   class ConfigManager:
       def _gen_buffer_config(self):
           buffer_config = {
               "batch_size": st.session_state["train_batch_size"],
               # Additional configuration parameters
           }
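The value assignment in the last step amounts to copying flat page values into a nested dictionary. The sketch below shows that mapping with a plain dict standing in for st.session_state. The function gen_buffer_config, the placement of total_epochs under buffer, and the value 20 are assumptions made for illustration; check generate_config in the current code for the real layout.

```python
from typing import Any, Dict


def gen_buffer_config(session_state: Dict[str, Any]) -> Dict[str, Any]:
    """Copy flat page values into the nested structure of the output config."""
    return {
        "batch_size": session_state["train_batch_size"],
        "total_epochs": session_state["total_epochs"],  # assumed placement
    }


# A plain dict stands in for st.session_state; the values are made up.
state = {"train_batch_size": 96, "total_epochs": 20}
config = {"buffer": gen_buffer_config(state)}
print(config)
```

Note that the page-side name (train_batch_size) and the YAML field name (batch_size) differ, so the mapping has to be written explicitly for each new parameter.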
Following these steps adds the new parameter to the Config Generator page and integrates it into the generated configuration.
Check Code Style
Before submitting the code, make sure it passes the code style check. Follow these steps:
# Install code style checking tools
cd <path_to_trinity_rft>
# bash
pip install -e .[dev]
# zsh
# pip install -e .\[dev\]
# Run code style checks
pre-commit run --all-files
# Commit the code after all checks pass
git commit -am "create example workflow"