Developer Guide
This guide introduces how to develop new modules in Trinity-RFT and provides relevant development guidelines.
In Trinity-RFT, we decompose the RL pipeline into three main modules (Explorer, Trainer and Buffer) to facilitate customization and extension. Below is a table summarizing the modules and components that developers with different targets need to focus on.
Development Target | Core Module | Key Component
---|---|---
Apply existing RL algorithms to new environments. | Explorer | Workflow
Design new RL algorithms. | Trainer | Algorithm
Enhance the RL process from the data perspective. | Buffer | Data Processing Module (Coming soon)
Note
Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
Workflows (For RL Environment Developers)
In Trinity-RFT, workflows are the core components that define the interaction between Agents and Environments. A qualified workflow uses the model being trained to complete the specified task and obtains feedback (reward) from the environment. Below are the steps to create a new workflow:
Step 0: Basic Concepts
Before starting development, it’s important to understand several core concepts:
- **Task** (`trinity.common.workflows.Task`): Represents a data structure that can be converted into a `Workflow`. The content of the `Task` varies depending on the task type:
  - Math problems: A `Task` contains the problem description and the golden answer.
  - Programming scenarios: A `Task` includes the problem description, test cases, runtime environment, and other complex information.
- **Workflow** (`trinity.common.workflows.Workflow`): Describes how a `Task` is executed. It defines the interaction flow between Agents and Environments, including logic similar to rollout and reward calculation in other frameworks. After execution, it generates a list of `Experience`. Trinity-RFT includes several built-in workflows:
  - `MathWorkflow` (`trinity.common.workflows.MathWorkflow`): For math scenarios; submits problems to the LLM, parses LLM responses, and calculates scores (rewards).
  - `WebShopWorkflow` (`trinity.common.workflows.WebShopWorkflow`): For webshop scenarios; contains multi-turn interaction with the environment.
  - `CodeWorkflow` (Coming soon): For coding scenarios; executes the returned code, runs tests, and calculates rewards based on test results.
  - …
- **Experience** (`trinity.common.experience.Experience`): The output of running a `Workflow`. The internal data format depends on the training algorithm used. For example, for common PPO/GRPO algorithms, an `Experience` includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.
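To make this concrete, here is a minimal sketch of constructing an `Experience` with the PPO/GRPO-style fields described above (the values are illustrative, and the exact field types may differ from the actual class):

```python
from trinity.common.experience import Experience

# Illustrative values: 3 prompt tokens followed by 2 generated tokens.
exp = Experience(
    tokens=[101, 2054, 102, 1015, 1016],  # prompt + response token IDs
    prompt_length=3,                      # the first 3 tokens form the prompt
    reward=1.0,                           # scalar feedback from the environment
    logprobs=[-0.21, -0.48],              # log probabilities of generated tokens
)
```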
Step 1: Prepare Task Dataset
The task dataset is loaded via the `buffer.explorer_input.taskset` configuration entry in your YAML config file.

To handle differences in `Task` contents, Trinity-RFT provides a unified `Task` interface containing the following fields:
- `workflow` (`str`): The registered name of your workflow class. You can specify it in `buffer.explorer_input.taskset.default_workflow_type` of your YAML config file.
- `reward_fn` (`Optional[str]`): The registered name of your reward function. You can specify it in `buffer.explorer_input.taskset.default_reward_fn_type`. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.
- `raw_task` (`Dict`): A record of raw data in `Dict` format. For a highly customized workflow, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields.
- `format_args` (`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.taskset.format`.
- `rollout_args` (`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.taskset.rollout_args`.
- `workflow_args` (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. Provides more flexibility than `format_args` and `rollout_args` by using a dictionary. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.taskset.workflow_args`. Normally, you do not need to set this field.
Tip

`workflow`, `workflow_args`, and `raw_task` provide different levels of customization:

- `workflow` provides the global settings for all tasks that use the same workflow. (Global Level)
- `workflow_args` can be set for each task dataset, allowing different task datasets using the same workflow to behave differently. (Dataset Level)
- `raw_task` customizes the behavior of each individual task, which is the most flexible. (Data Sample Level)
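For example, a multi-turn workflow could read a dataset-level option from `workflow_args` during initialization. A minimal sketch (the `max_turns` key and the workflow name are purely illustrative):

```python
class MultiTurnWorkflow(Workflow):  # hypothetical workflow
    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
        super().__init__(model, task, auxiliary_models)
        # Read a dataset-level knob set in buffer.explorer_input.taskset.workflow_args.
        self.max_turns = task.workflow_args.get("max_turns", 3)
```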
In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line is a JSON object with `question` and `answer` fields representing the problem description and the standard answer, respectively. For example:

```jsonl
{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...
```
Example configuration snippet:

```yaml
# some config
buffer:
  explorer_input:
    taskset:
      default_workflow_type: "math_workflow"
      path: "/PATH/TO/FILE/DIR"
      format:
        prompt_key: "question"
        response_key: "answer"
      rollout_args:
        temperature: 1.0
# some other configs
```
In this example, each task object’s `raw_task` is a `Dict` with two keys (`question` and `answer`). The `MathWorkflow` uses `prompt_key` and `response_key` to extract the question and answer from `raw_task`, and uses `rollout_args` to generate the response.
Step 2: Implement a New Workflow
The `Workflow` base class interface is as follows:

```python
from abc import ABC, abstractmethod
from typing import List, Optional

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows import Task


class Workflow(ABC):
    def __init__(
        self,
        model: ModelWrapper,
        task: Task,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        self.model = model
        self.auxiliary_models = auxiliary_models

    @abstractmethod
    def run(self) -> List[Experience]:
        """Run the workflow and return a list of Experiences."""
```
Initialize Your Workflow

During initialization, `Workflow` receives the following parameters:

- `model` (`trinity.common.models.model.ModelWrapper`): The model being trained. It provides an interface similar to OpenAI’s, accepting a list of conversation messages and returning content generated by the LLM, including the reply text (`response_text`), the full-sequence token IDs (`tokens`), the prompt token length (`prompt_length`), and a list of log probabilities for the output tokens (`logprobs`).
- `task` (`trinity.common.workflows.Task`): A single data item from the task dataset.
- `auxiliary_models` (`List[openai.OpenAI]`): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.
Tip

You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
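A rough sketch of this inside a workflow (assuming `enable_openai_api` is set to `true`; the model-listing call is the standard way to discover the served model name from an OpenAI-compatible server):

```python
# Inside your workflow, e.g., in __init__ or run():
openai_client = self.model.get_openai_client()
model_name = openai_client.models.list().data[0].id  # served model name
response = openai_client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": f"Question:\n{self.question}"}],
)
reply_text = response.choices[0].message.content
```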
Here’s an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use `format_args` for further customization.

```python
class ExampleWorkflow(Workflow):
    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
        super().__init__(model, task, auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args
        # Optional: If you want to use OpenAI API in your workflow
        # self.openai_client = self.model.get_openai_client()
```
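If your dataset uses different column names, `format_args` lets you avoid hard-coding the keys. A sketch under the assumption that the taskset’s `format` section is configured as in Step 1 (the workflow name is hypothetical):

```python
class FormatAwareWorkflow(Workflow):  # hypothetical name
    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
        super().__init__(model, task, auxiliary_models)
        # prompt_key / response_key come from buffer.explorer_input.taskset.format.
        self.question = task.raw_task.get(task.format_args.prompt_key)
        self.answer = task.raw_task.get(task.format_args.response_key)
        self.rollout_args = task.rollout_args
```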
Implementing the `run` Method

The `run` method is the core of your workflow. It returns a list of `Experience`. Below is a simple implementation for a math workflow. We first call the model to generate multiple responses using the provided question and rollout arguments. Then we calculate the reward for each response using the `calculate_reward` function. Finally, we construct a list of `Experience` from the responses and rewards and return it.
```python
class ExampleWorkflow(Workflow):
    # the __init__ function

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences
```
Registering Your Workflow
Register your workflow using the `WORKFLOWS.register_module` decorator. Ensure the name does not conflict with existing workflows.

```python
# import some packages
from trinity.common.workflows.workflow import WORKFLOWS


@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    pass
```

For workflows that you intend to contribute to the Trinity-RFT project, place the above code in the `trinity/common/workflows` folder, e.g., `trinity/common/workflows/example_workflow.py`, and add the following lines to `trinity/common/workflows/__init__.py`:

```python
# existing import lines
from .example_workflow import ExampleWorkflow

__all__ = [
    # existing __all__ lines
    "ExampleWorkflow",
]
```

For workflows not intended to be contributed to the Trinity-RFT project, you can simply place the above code in `trinity/plugins`. Trinity-RFT will automatically detect and load all custom modules in this folder.
Tip

You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT, e.g., `trinity run --config <your_yaml_file> --plugin-dir <your_plugin_dir>`. If you don’t specify `--plugin-dir`, Trinity-RFT will use `<Trinity_RFT_ROOT_DIR>/trinity/plugins` as the default directory.
Avoid Re-initialization
For heavy workflows, re-initializing on every task can incur extra computational costs. In this case, you can implement the `resettable` and `reset` methods so that a workflow instance can be reused and reset with the next task instead of being constructed from scratch.

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    # some code
    # ...

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
```
Full Code Example
```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
        super().__init__(model, task, auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
```
Step 3: Use Your Workflow
After implementing and registering your workflow, update the configuration file by setting `default_workflow_type` in the `buffer.explorer_input.taskset` section to the newly registered `Workflow` name.

```yaml
buffer:
  # Other fields
  explorer_input:
    taskset:
      path: /path/to/taskset
      default_workflow_type: example_workflow
  # Other fields
```

Now you can run your workflow in Trinity-RFT using the command:

```bash
trinity run --config <your_yaml_file>
```
Algorithms (For RL Algorithm Developers)
Trinity-RFT provides a standardized process for implementing new algorithms.
Step 0: Basic Concepts of Algorithm Module
In Trinity-RFT, the algorithm module is primarily responsible for extracting experience data from the Replay Buffer during the RL process and calculating the loss to update models based on this data. To avoid implementing a new Trainer class each time a new algorithm is added, we have decomposed the representative PPO algorithm process into multiple sub-modules to adapt to various algorithms.
- **Sample Strategy** (`trinity.algorithm.SampleStrategy`): Responsible for sampling experience data from the buffer module. By customizing this module, you can implement functionalities like filtering experience data or mixed sampling from multiple data sources.
- **Advantage Fn** (`trinity.algorithm.AdvantageFn`): Responsible for calculating the advantages and returns of experience data.
- **Policy Loss Fn** (`trinity.algorithm.PolicyLossFn`): Responsible for calculating the core training loss of the policy network.
- **KL Fn** (`trinity.algorithm.KLFn`): Responsible for calculating KL divergence, which is generally used in two places in existing RL algorithms: reward penalty and actor loss.
- **Entropy Loss Fn** (`trinity.algorithm.EntropyLossFn`): Responsible for calculating the entropy loss of the policy network.

We provide several implementations of the above modules in `trinity/algorithm`.
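Conceptually, one training step wires these modules together roughly as follows. This is simplified pseudocode for orientation only; the variable and method names are placeholders rather than the actual Trainer internals:

```python
# Simplified pseudocode of one RL training step (not the actual Trainer code):
exps, sample_metrics = sample_strategy.sample(step)  # Sample Strategy: read from buffer
exps, adv_metrics = advantage_fn(exps)               # Advantage Fn: attach advantages/returns
loss, loss_metrics = policy_loss_fn(                 # Policy Loss Fn: core policy loss
    logprob=logprob, action_mask=action_mask, advantages=advantages
)
# KL Fn adds a reward penalty and/or a KL loss term;
# Entropy Loss Fn adds entropy regularization.
```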
Step 1: Implement Algorithm Components
Trinity-RFT allows developers to customize all the above modules. Developers only need to implement specific modules according to the requirements of their new algorithm. This section will provide a simple introduction using the OPMD algorithm as an example.
The main difference between OPMD and PPO algorithms lies in the calculation of Advantage and Policy Loss. Therefore, only new Advantage Fn and Policy Loss Fn modules need to be implemented.
Step 1.1: Implement `AdvantageFn`

Developers need to implement the `trinity.algorithm.AdvantageFn` interface, which mainly includes two methods:

- `__call__`: Calculates advantages and returns based on the input experience data, records observable metrics during the calculation, and returns the experience data (now containing advantages and returns) together with a metrics dictionary. The input experience data format is verl’s `DataProto`.
- `default_args`: Returns the default initialization parameters in dictionary form; these are used when users don’t specify initialization parameters in the configuration file.

After implementation, you need to register this module through `trinity.algorithm.ADVANTAGE_FN`. Once registered, the module can be configured in the configuration file using its registered name.
Here’s an implementation example for the OPMD algorithm’s Advantage Fn:

```python
# trinity/algorithm/advantage_fn/opmd.py
from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn


@ADVANTAGE_FN.register_module("opmd")
class OPMDAdvantageFn(AdvantageFn):
    """OPMD advantage computation"""

    def __init__(
        self,
        opmd_baseline: str = "mean",
        tau: float = 1.0,
    ) -> None:
        self.opmd_baseline = opmd_baseline
        self.tau = tau

    def __call__(
        self,
        exps: DataProto,
        **kwargs,
    ) -> Tuple[DataProto, Dict]:
        metrics = {}
        # calculate advantages and returns based on the exps
        # record some metrics
        return exps, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {
            "opmd_baseline": "mean",
            "tau": 1.0,
        }
```
Step 1.2: Implement `PolicyLossFn`

Developers need to implement the `trinity.algorithm.PolicyLossFn` interface, which is similar to `AdvantageFn` and includes two methods:

- `__call__`: Calculates the loss based on input parameters. Unlike `AdvantageFn`, the input parameters here are all `torch.Tensor`. The interface automatically scans the parameter list of the `__call__` method and maps each parameter name to the corresponding field in the experience data. Therefore, write all tensor names needed for loss calculation directly in the parameter list, rather than selecting parameters from `kwargs`.
- `default_args`: Returns the default initialization parameters in dictionary form; these are used when users don’t specify initialization parameters in the configuration file.

Similarly, after implementation, you need to register this module through `trinity.algorithm.POLICY_LOSS_FN`.

Here’s an implementation example for the OPMD algorithm’s Policy Loss Fn. Since OPMD’s policy loss only requires `logprob`, `action_mask`, and `advantages`, only these three items appear in the parameter list of the `__call__` method:
```python
from typing import Dict, Tuple

import torch

from trinity.algorithm import POLICY_LOSS_FN, PolicyLossFn
# masked_mean (a helper from the project's utilities) averages per-token
# values over the positions selected by action_mask.


@POLICY_LOSS_FN.register_module("opmd")
class OPMDPolicyLossFn(PolicyLossFn):
    def __init__(self, tau: float = 1.0) -> None:
        self.tau = tau

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        pg_losses = -advantages * logprob
        opmd_loss = masked_mean(pg_losses, action_mask)
        opmd_loss = opmd_loss / (1.0 + self.tau)  # for regularization (w.r.t. current pi_theta)
        return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}

    @classmethod
    def default_args(cls) -> Dict:
        return {"tau": 1.0}
```
Step 2: Register Your Algorithm
The above steps implement the components needed for the algorithm, but these components are scattered and would need to be configured in multiple places to take effect. To simplify configuration, Trinity-RFT provides `trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in `ALGORITHM_TYPE`, enabling one-click configuration.

The `AlgorithmType` class includes the following attributes and methods:

- `use_critic`: Whether to use the Critic model.
- `use_reference`: Whether to use the Reference model.
- `use_advantage`: Whether to calculate advantages; if `False`, the `AdvantageFn` call is skipped.
- `can_balance_batch`: Whether the algorithm allows automatic balancing when splitting a batch into micro-batches (which may permute the order of samples).
- `schema`: The format of the experience data corresponding to the algorithm.
- `default_config`: Returns the default configuration of the algorithm, which fills the fields of the same name in the `algorithm` config section when the user does not set them.

Similarly, after implementation, you need to register this module through `ALGORITHM_TYPE`.
Below is the implementation for the OPMD algorithm. Since the OPMD algorithm doesn’t need the Critic model, `use_critic` is set to `False`. The dictionary returned by the `default_config` method indicates that OPMD will use the `opmd`-type `AdvantageFn` and `PolicyLossFn` implemented in Step 1, will not apply a KL penalty on rewards, but will add a `k2`-type KL loss when calculating the final loss.

```python
@ALGORITHM_TYPE.register_module("opmd")
class OPMDAlgorithm(AlgorithmType):
    """OPMD algorithm."""

    use_critic: bool = False
    use_reference: bool = True
    use_advantage: bool = True
    can_balance_batch: bool = True
    schema: type = ExperienceModel

    @classmethod
    def default_config(cls) -> Dict:
        return {
            "repeat_times": 2,
            "sample_strategy": "warmup",
            "policy_loss_fn": "opmd",
            "advantage_fn": "opmd",
            "kl_penalty_fn": "none",
            "kl_loss_fn": "k2",
            "entropy_loss_fn": "default",
        }
```
Step 3: Use Your Algorithm
After completing all the above steps, you can use the newly registered algorithm through a YAML configuration file. For the default configuration, you just need to add the following content to your `config.yaml` file:

```yaml
# some other configs
algorithm:
  algorithm_type: "opmd"
# some other configs
```

If you need to modify certain parameters, simply add the corresponding parameters within the `algorithm` section. For example, if you need to modify `repeat_times` and the initialization parameters of `AdvantageFn` and `PolicyLossFn`, the modified `config.yaml` file would look like this:

```yaml
# some other configs
algorithm:
  algorithm_type: "opmd"
  repeat_times: 8
  advantage_fn_args:
    opmd_baseline: "logavgexp"
    tau: 0.99
  policy_loss_fn_args:
    tau: 0.99
# some other configs
```
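For reference, these `*_args` dictionaries are passed as keyword arguments when the registered components are constructed, roughly equivalent to the following (a simplified sketch using the classes from Step 1; unspecified entries fall back to each component’s `default_args()`):

```python
# Roughly what the configuration above amounts to:
advantage_fn = OPMDAdvantageFn(opmd_baseline="logavgexp", tau=0.99)
policy_loss_fn = OPMDPolicyLossFn(tau=0.99)
```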
Adding New Config Entries for the Config Generator (Advanced)
Step 0: Understanding Streamlit
Before adding new parameters to the Config Generator page, familiarize yourself with the relevant APIs and mechanisms of Streamlit. This project primarily uses Streamlit’s input components and employs `st.session_state` to store user-input parameters.
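If you are new to Streamlit, the following minimal sketch shows the pattern the Config Generator relies on: a widget registered with a `key` writes its value into `st.session_state`, where other code can read it back (the label and key below are illustrative):

```python
import streamlit as st

# The key ties this widget's value to st.session_state["example_param"].
st.number_input("Example Param", key="example_param", value=1)

# Elsewhere, the stored value is available via session state.
st.write(st.session_state["example_param"])
```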
Step 1: Implement New Config Entries
To illustrate the process of creating a new parameter setting for the Config Generator page, we will use `train_batch_size` as an example.

1. Determine the appropriate scope for the parameter. Currently, parameters are categorized into four files:

   - `trinity/manager/config_registry/buffer_config_manager.py`
   - `trinity/manager/config_registry/explorer_config_manager.py`
   - `trinity/manager/config_registry/model_config_manager.py`
   - `trinity/manager/config_registry/trainer_config_manager.py`

   In this case, `train_batch_size` should be placed in the `buffer_config_manager.py` file.

2. Create a parameter setting function using Streamlit. The function name must follow the convention of starting with `set_`; the remainder of the name becomes the config name.

3. Decorate the parameter setting function with the `CONFIG_GENERATORS.register_config` decorator. This decorator requires the following information:

   - Default value of the parameter
   - Visibility condition (if applicable)
   - Additional config parameters (if needed)

Note

The `CONFIG_GENERATORS.register_config` decorator automatically passes `key=config_name` as an argument to the registered configuration function. Ensure that your function accepts this keyword argument.

For `train_batch_size`, we will use the following settings:

- Default value: `96`
- Visibility condition: `lambda: st.session_state["trainer_gpu_num"] > 0`
- Additional config: `{"_train_batch_size_per_gpu": 16}`

Here’s the complete code for the `train_batch_size` parameter:
```python
@CONFIG_GENERATORS.register_config(
    default_value=96,
    visible=lambda: st.session_state["trainer_gpu_num"] > 0,
    other_configs={"_train_batch_size_per_gpu": 16},
)
def set_train_batch_size(**kwargs):
    key = kwargs.get("key")
    trainer_gpu_num = st.session_state["trainer_gpu_num"]
    st.session_state[key] = (
        st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"]
    )

    def on_change():
        st.session_state["_train_batch_size_per_gpu"] = max(
            st.session_state[key] // st.session_state["trainer_gpu_num"], 1
        )

    st.number_input(
        "Train Batch Size",
        min_value=trainer_gpu_num,
        step=trainer_gpu_num,
        help=_str_for_train_batch_size(),
        on_change=on_change,
        **kwargs,
    )
```
If the parameter requires validation, create a check function. For `train_batch_size`, we need to ensure it is divisible by `trainer_gpu_num`. If not, a warning should be displayed, and the parameter should be added to `unfinished_fields`.

Decorate the check function with the `CONFIG_GENERATORS.register_check` decorator:

```python
@CONFIG_GENERATORS.register_check()
def check_train_batch_size(unfinished_fields: set, key: str):
    if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0:
        unfinished_fields.add(key)
        st.warning(_str_for_train_batch_size())
```
Note

The `CONFIG_GENERATORS.register_check` decorator automatically receives `key=config_name` and `unfinished_fields=self.unfinished_fields` as arguments. Ensure your function accepts these keyword arguments.
Step 2: Integrating New Parameters into `config_manager.py`

To integrate new parameters into the `config_manager.py` file, adhere to the following procedure:

1. **Parameter Categorization**: Determine the appropriate section for the new parameter based on its functionality. The config generator page is structured into two primary modes:

   - Beginner Mode: comprises the “Essential Configs” and “Important Configs” sections.
   - Expert Mode: includes the “Model”, “Buffer”, “Explorer and Synchronizer”, and “Trainer” sections.

2. **Parameter Addition**: Incorporate the new parameter into the relevant section using the `self.get_configs` method within the `ConfigManager` class. Example:

   ```python
   class ConfigManager:
       def _expert_buffer_part(self):
           self.get_configs("total_epochs", "train_batch_size")
   ```

3. **YAML File Integration**: Locate the appropriate position for the new parameter within the YAML file structure. This should be done in the `generate_config` function and its associated sub-functions.

4. **Parameter Value Assignment**: Use `st.session_state` to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML. Example:

   ```python
   class ConfigManager:
       def _gen_buffer_config(self):
           buffer_config = {
               "batch_size": st.session_state["train_batch_size"],
               # Additional configuration parameters
           }
   ```

By following these steps, you can add new parameters to the Config Generator page and integrate them properly into the configuration system.
Check Code Style
Before submitting the code, make sure it passes the code style check. Follow these steps:
```bash
# Install code style checking tools
cd <path_to_trinity_rft>
# bash
pip install -e .[dev]
# zsh
# pip install -e .\[dev\]

# Run code style checks
pre-commit run --all-files

# Commit the code after all checks pass
git commit -am "create example workflow"
```