Developer Guide

This guide introduces how to develop new modules in Trinity-RFT and provides relevant development guidelines.

In Trinity-RFT, we decompose the RL pipeline into three main modules (Explorer, Trainer and Buffer) to facilitate customization and extension. Below is a table summarizing the modules and components that developers with different targets need to focus on.

Development Target                                | Core Module | Key Component
--------------------------------------------------|-------------|--------------------------------------
Apply existing RL algorithms to new environments. | Explorer    | Workflow
Design new RL algorithms.                         | Trainer     | Algorithm
Enhance the RL process from the data perspective. | Buffer      | Data Processing Module (Coming soon)

Note

Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.


Workflows (For RL Environment Developers)

In Trinity-RFT, workflows are the core components that define the interaction between Agents and Environments. A workflow must use the model being trained to complete the specified task and obtain feedback (reward) from the environment. Below are the steps to create a new workflow:


Step 0: Basic Concepts

Before starting development, it’s important to understand several core concepts:

  • Task (trinity.common.workflows.Task): Represents a data structure that can be converted into a Workflow. The content of the Task varies depending on the task type:

    • Math problems: A Task contains the problem description and the golden answer.

    • Programming scenarios: A Task includes the problem description, test cases, runtime environment, and other complex information.

  • Workflow (trinity.common.workflows.Workflow): Describes how a Task is executed. It defines the interaction flow between Agents and Environments, including logic similar to Rollout and Reward calculations in other frameworks. After execution, it generates a list of Experience. Trinity-RFT includes several built-in workflows:

    • MathWorkflow (trinity.common.workflows.MathWorkflow): For math scenarios; submits problems to the LLM, parses LLM responses, and calculates scores (rewards).

    • WebShopWorkflow (trinity.common.workflows.WebShopWorkflow): For WebShop scenarios; involves multi-turn interaction with the environment.

    • CodeWorkflow (Coming soon): For coding scenarios, executes returned code, runs tests, and calculates rewards based on test results.

  • Experience (trinity.common.experience.Experience): The output of running a Workflow. The internal data format depends on the training algorithm used. For example, for common PPO/GRPO algorithms, Experience includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.
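For intuition, here is a minimal sketch of how an Experience is assembled for PPO/GRPO-style training (the field names follow the example code later in this guide; the response object stands in for a model output):

from trinity.common.experience import Experience

experience = Experience(
    tokens=response.tokens,                # token ids of the full sequence (prompt + response)
    prompt_length=response.prompt_length,  # tokens before this index belong to the prompt
    reward=1.0,                            # scalar feedback from the environment
    logprobs=response.logprobs,            # log probabilities of the generated tokens
)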


Step 1: Prepare Task Dataset

The task dataset is loaded via the buffer.explorer_input.taskset configuration entry in your YAML config file. To handle differences in Task contents, Trinity-RFT provides a unified Task interface containing the following fields.

  • workflow (str): The registered name of your workflow class. You can specify it in buffer.explorer_input.taskset.default_workflow_type of your YAML config file.

  • reward_fn (Optional[str]): The registered name of your reward function. You can specify it in buffer.explorer_input.taskset.default_reward_fn_type. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.

  • raw_task (Dict): A record of raw data in Dict format. For highly customized workflows, you can directly use raw_task to initialize your Workflow instance without relying on the following fields.

  • format_args (trinity.common.config.FormatConfig): Parameters that facilitate the construction of Workflow instances. For example, prompt_key and response_key can be used to extract the prompt and response from raw_task. These settings come from the YAML configuration file and can be set in buffer.explorer_input.taskset.format.

  • rollout_args (trinity.common.config.GenerationConfig): Parameters that control the rollout process, such as temperature. This field also comes from the YAML configuration file and can be set in buffer.explorer_input.taskset.rollout_args.

  • workflow_args (Dict): A dictionary of parameters that facilitate the construction of Workflow instances. It provides more flexibility than format_args and rollout_args. This field also comes from the YAML configuration file and can be set in buffer.explorer_input.taskset.workflow_args. Normally, you do not need to set this field.

Tip

workflow, workflow_args and raw_task provide different levels of customization.

  • workflow provides the global settings for all tasks that use the same workflow. (Global Level)

  • workflow_args can be set for each task dataset, allowing different task datasets that use the same workflow to behave differently, as shown in the example below. (Dataset Level)

  • raw_task allows customizing the behavior of each individual task, which is the most flexible. (Data Sample Level)
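For example, a Dataset Level option can be passed through workflow_args in the YAML config. The use_cot flag below is hypothetical and only illustrates the mechanism; a workflow would read it from task.workflow_args during initialization:

buffer:
  explorer_input:
    taskset:
      default_workflow_type: "math_workflow"
      workflow_args:
        use_cot: true  # hypothetical flag consumed by the workflow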

In the math problem scenario, the task dataset can be a jsonl file, where each line is a JSON object with question and answer fields representing the problem description and the standard answer, respectively. For example:

{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...

Example configuration snippet:

# some config
buffer:
  explorer_input:
    taskset:
      default_workflow_type: "math_workflow"
      path: "/PATH/TO/FILE/DIR"
      format:
        prompt_key: "question"
        response_key: "answer"
      rollout_args:
        temperature: 1.0
      # some other configs

In this example, each task object’s raw_task is a Dict with two keys (question and answer). The MathWorkflow uses prompt_key and response_key to extract the question and answer from raw_task, and uses rollout_args to generate responses.


Step 2: Implement a New Workflow

The Workflow base class interface is as follows:

class Workflow(ABC):

    def __init__(
        self,
        model: ModelWrapper,
        task: Task,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        self.model = model
        self.auxiliary_models = auxiliary_models

    @abstractmethod
    def run(self) -> List[Experience]:
        """Run the workflow and return a list of Experiences."""

Initialize Your Workflow

During initialization, Workflow receives the following parameters:

  • model (trinity.common.models.model.ModelWrapper): The model being trained. It provides an OpenAI-like interface that accepts a list of conversation messages and returns content generated by the LLM, including the reply text (response_text), the token ids of the full sequence (tokens), the token length of the prompt part (prompt_length), and the log probabilities of the output tokens (logprobs).

  • task (trinity.common.workflows.Task): A single data item from the task dataset.

  • auxiliary_models (List[openai.OpenAI]): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.

Tip

You can switch to using the OpenAI API by setting explorer.rollout_model.enable_openai_api to true in your config file and calling model.get_openai_client() to get an openai.OpenAI instance in your workflow.
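As a rough sketch of what this looks like inside a workflow (how the model name is looked up here is an assumption; adjust it to your deployment):

# Assumes explorer.rollout_model.enable_openai_api is true in the config.
client = self.model.get_openai_client()
# Assumption: the served model is the first (and only) entry in the model list.
model_id = client.models.list().data[0].id
completion = client.chat.completions.create(
    model=model_id,
    messages=[{"role": "user", "content": f"Question:\n{self.question}"}],
)
response_text = completion.choices[0].message.content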

Here’s an example of initializing a simple workflow using only raw_task and rollout_args. In more complex cases, you can use the format_args for further customization.

class ExampleWorkflow(Workflow):

    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
        super().__init__(model, task, auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args
        # Optional: If you want to use OpenAI API in your workflow
        # self.openai_client = self.model.get_openai_client()

Implementing the run method

The run method is the core of your workflow. It returns a list of Experience. Below is a simple implementation for a math workflow.

We first call the model to generate multiple responses for the given question using the provided rollout arguments. Then we calculate the reward for each response using the calculate_reward function. Finally, we construct and return a list of Experience objects from the responses and rewards.

class ExampleWorkflow(Workflow):

    # the __init__ function

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

Registering Your Workflow

Register your workflow using the WORKFLOWS.register_module decorator. Ensure the name does not conflict with existing workflows.

# import some packages
from trinity.common.workflows.workflow import WORKFLOWS

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    pass

For workflows intended to be contributed to the Trinity-RFT project, place the above code in the trinity/common/workflows folder, e.g., trinity/common/workflows/example_workflow.py, and add the following lines to trinity/common/workflows/__init__.py:

# existing import lines
from .example_workflow import ExampleWorkflow

__all__ = [
    # existing __all__ lines
    "ExampleWorkflow",
]

For workflows not intended to be contributed to the Trinity-RFT project, you can simply place the above code in trinity/plugins. Trinity-RFT automatically detects and loads all custom modules in this folder.

Tip

You can specify the directory where your custom modules are located by setting --plugin-dir when starting Trinity-RFT. If you don’t specify --plugin-dir, Trinity-RFT will use <Trinity_RFT_ROOT_DIR>/trinity/plugins as the default directory.
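For example (the paths are illustrative):

trinity run --config <your_yaml_file> --plugin-dir /path/to/your/plugins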

Avoid Re-initialization

For heavy workflows, re-initializing on every task can incur extra computational cost. In this case, you can implement the resettable and reset methods to avoid re-initialization.

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    # some code
    # ...

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

Full Code Example

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
        super().__init__(model, task, auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

Step 3: Use Your Workflow

After implementing and registering your workflow, update the configuration file by setting default_workflow_type under buffer.explorer_input.taskset to the newly registered workflow name.

buffer:
  # Other fields
  explorer_input:
    taskset:
      path: /path/to/taskset
      default_workflow_type: example_workflow
      # Other fields

Now you can run your workflow in Trinity-RFT using the command:

trinity run --config <your_yaml_file>

Algorithms (For RL Algorithm Developers)

Trinity-RFT provides a standardized process for implementing new algorithms.

Step 0: Basic Concepts of Algorithm Module

In Trinity-RFT, the algorithm module is primarily responsible for extracting experience data from the Replay Buffer during the RL process and calculating the loss used to update models. To avoid implementing a new Trainer class each time a new algorithm is added, we have decomposed the representative PPO algorithm into multiple sub-modules that can be recombined to express various algorithms.

  • Sample Strategy (trinity.algorithm.SampleStrategy): Responsible for sampling experience data from the buffer module. By customizing this module, you can implement functionality such as filtering experience data or mixed sampling from multiple data sources.

  • Advantage Fn (trinity.algorithm.AdvantageFn): Responsible for calculating the advantages and returns of experience data.

  • Policy Loss Fn (trinity.algorithm.PolicyLossFn): Responsible for calculating the core training loss of the policy network.

  • KL Fn (trinity.algorithm.KLFn): Responsible for calculating KL divergence, which is generally used in two places in existing RL algorithms: reward penalty and actor loss.

  • Entropy Loss Fn (trinity.algorithm.EntropyLossFn): Responsible for calculating the entropy loss of the policy network.

We provide several implementations of the above modules in trinity/algorithm.


Step 1: Implement Algorithm Components

Trinity-RFT allows developers to customize all the above modules. Developers only need to implement specific modules according to the requirements of their new algorithm. This section will provide a simple introduction using the OPMD algorithm as an example.

The main difference between OPMD and PPO algorithms lies in the calculation of Advantage and Policy Loss. Therefore, only new Advantage Fn and Policy Loss Fn modules need to be implemented.
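To make the difference concrete, here is one plausible reading of the group-level idea behind OPMD's advantage (the authoritative formula lives in trinity/algorithm/advantage_fn/opmd.py): given N rollouts of the same task with rewards r_1, ..., r_N, a baseline b is computed over the group, e.g., b = mean(r) for the "mean" baseline or a tau-tempered log-average-exp of the rewards for the "logavgexp" baseline, and each rollout's advantage is A_i = r_i - b.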


Step 1.1: Implement AdvantageFn

Developers need to implement the trinity.algorithm.AdvantageFn interface, which mainly includes two methods:

  • __call__: Calculates advantages and returns based on input experience data, records observable metrics during the calculation process, and returns the experience data containing advantages and returns as well as a metrics dictionary. The input experience data format is verl’s DataProto.

  • default_args: Returns default initialization parameters in dictionary form, which will be used by default when users don’t specify initialization parameters in the configuration file.

After implementation, you need to register this module through trinity.algorithm.ADVANTAGE_FN. Once registered, the module can be configured in the configuration file using the registered name.

Here’s an implementation example for the OPMD algorithm’s Advantage Fn:

# trinity/algorithm/advantage_fn/opmd.py
from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn


@ADVANTAGE_FN.register_module("opmd")
class OPMDAdvantageFn(AdvantageFn):
    """OPMD advantage computation"""

    def __init__(
        self,
        opmd_baseline: str = "mean",
        tau: float = 1.0,
    ) -> None:
        self.opmd_baseline = opmd_baseline
        self.tau = tau

    def __call__(
        self,
        exps: DataProto,
        **kwargs,
    ) -> Tuple[DataProto, Dict]:
        # calculate advantages and returns based on the exps,
        # recording observable metrics along the way
        metrics = {}

        return exps, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {
            "opmd_baseline": "mean",
            "tau": 1.0,
        }

Step 1.2: Implement PolicyLossFn

Developers need to implement the trinity.algorithm.PolicyLossFn interface, which is similar to AdvantageFn and includes two methods:

  • __call__: Calculates the loss based on input parameters. Unlike AdvantageFn, the input parameters here are all torch.Tensor. This interface automatically scans the parameter list of the __call__ method and converts it to the corresponding fields in the experience data. Therefore, please write all tensor names needed for loss calculation directly in the parameter list, rather than selecting parameters from kwargs.

  • default_args: Returns default initialization parameters in dictionary form, which will be used by default when users don’t specify initialization parameters in the configuration file.

Similarly, after implementation, you need to register this module through trinity.algorithm.POLICY_LOSS_FN.

Here’s an implementation example for the OPMD algorithm’s Policy Loss Fn. Since OPMD’s Policy Loss only requires logprob, action_mask, and advantages, only these three items are specified in the parameter list of the __call__ method:

from typing import Dict, Tuple

import torch

from trinity.algorithm import POLICY_LOSS_FN, PolicyLossFn
# masked_mean(values, mask) averages `values` over positions where `mask` is 1


@POLICY_LOSS_FN.register_module("opmd")
class OPMDPolicyLossFn(PolicyLossFn):
    def __init__(self, tau: float = 1.0) -> None:
        self.tau = tau

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        pg_losses = -advantages * logprob
        opmd_loss = masked_mean(pg_losses, action_mask)
        opmd_loss = opmd_loss / (1.0 + self.tau)  # for regularization (w.r.t. current pi_theta)
        return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}

    @classmethod
    def default_args(cls) -> Dict:
        return {"tau": 1.0}

Step 2: Register Your Algorithm

The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect.

To simplify configuration, Trinity-RFT provides trinity.algorithm.AlgorithmType to describe a complete algorithm and registers it in ALGORITHM_TYPE, enabling one-click configuration.

The AlgorithmType class includes the following attributes and methods:

  • use_critic: Whether to use the Critic model

  • use_reference: Whether to use the Reference model

  • use_advantage: Whether to calculate Advantage; if False, the AdvantageFn call will be skipped

  • can_balance_batch: Whether the algorithm allows automatic balancing when splitting a batch into micro-batches (which permutes the order of samples)

  • schema: The format of experience data corresponding to the algorithm

  • default_config: Gets the default configuration of the algorithm, which will override attributes with the same name in ALGORITHM_TYPE

Similarly, after implementation, you need to register this module through ALGORITHM_TYPE.

Below is the implementation for the OPMD algorithm. Since the OPMD algorithm doesn’t need to use the Critic model, use_critic is set to False. The dictionary returned by the default_config method indicates that OPMD will use the opmd type AdvantageFn and PolicyLossFn implemented in Step 1, will not apply KL Penalty on rewards, but will add a k2 type KL loss when calculating the final loss.

@ALGORITHM_TYPE.register_module("opmd")
class OPMDAlgorithm(AlgorithmType):
    """OPMD algorithm."""

    use_critic: bool = False
    use_reference: bool = True
    use_advantage: bool = True
    can_balance_batch: bool = True
    schema: type = ExperienceModel

    @classmethod
    def default_config(cls) -> Dict:
        return {
            "repeat_times": 2,
            "sample_strategy": "warmup",
            "policy_loss_fn": "opmd",
            "advantage_fn": "opmd",
            "kl_penalty_fn": "none",
            "kl_loss_fn": "k2",
            "entropy_loss_fn": "default",
        }

Step 3: Use Your Algorithm

After completing all the above steps, you can use the newly registered algorithm through a YAML configuration file.

For default configurations, you just need to add the following content to your config.yaml file:

# some other configs
algorithm:
  algorithm_type: "opmd"
# some other configs

If you need to modify certain parameters, you can simply add the corresponding parameters within the algorithm section. For example, if you need to modify repeat_times and the initialization parameters of AdvantageFn and PolicyLossFn, the modified config.yaml file would be as follows:

# some other configs
algorithm:
  algorithm_type: "opmd"
  repeat_times: 8
  advantage_fn_args:
    opmd_baseline: "logavgexp"
    tau: 0.99
  policy_loss_fn_args:
    tau: 0.99
# some other configs

Adding New Config Entries for the Config Generator (Advanced)

Step 0: Understanding Streamlit

Before adding new parameters to the Config Generator page, it is essential to familiarize yourself with the relevant API and mechanisms of Streamlit. This project primarily utilizes various input components from Streamlit and employs st.session_state to store user-input parameters.
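If you are new to Streamlit, the pattern used throughout these pages is small; here is a minimal, self-contained sketch (the widget label and key are illustrative):

import streamlit as st

# An input widget registered with a key stores its value in st.session_state.
st.number_input("Train Batch Size", min_value=1, value=96, key="train_batch_size")

# The value can later be read back from session state anywhere on the page.
st.write(st.session_state["train_batch_size"])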

Step 1: Implement New Config Entries

To illustrate the process of creating a new parameter setting for the Config Generator page, we will use train_batch_size as an example.

  1. Determine the appropriate scope for the parameter. Currently, parameters are categorized into four files:

    • trinity/manager/config_registry/buffer_config_manager.py

    • trinity/manager/config_registry/explorer_config_manager.py

    • trinity/manager/config_registry/model_config_manager.py

    • trinity/manager/config_registry/trainer_config_manager.py

    In this case, train_batch_size should be placed in the buffer_config_manager.py file.

  2. Create a parameter setting function using Streamlit. The function name must follow the convention of starting with ‘set_’, and the remainder of the name becomes the config name.

  3. Decorate the parameter setting function with the CONFIG_GENERATORS.register_config decorator. This decorator requires the following information:

    • Default value of the parameter

    • Visibility condition (if applicable)

    • Additional config parameters (if needed)

Note

The CONFIG_GENERATORS.register_config decorator automatically passes key=config_name as an argument to the registered configuration function. Ensure that your function accepts this keyword argument.

For train_batch_size, we will use the following settings:

  • Default value: 96

  • Visibility condition: lambda: st.session_state["trainer_gpu_num"] > 0

  • Additional config: {"_train_batch_size_per_gpu": 16}

Here’s the complete code for the train_batch_size parameter:

@CONFIG_GENERATORS.register_config(
    default_value=96,
    visible=lambda: st.session_state["trainer_gpu_num"] > 0,
    other_configs={"_train_batch_size_per_gpu": 16},
)
def set_train_batch_size(**kwargs):
    key = kwargs.get("key")
    trainer_gpu_num = st.session_state["trainer_gpu_num"]
    st.session_state[key] = (
        st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"]
    )

    def on_change():
        st.session_state["_train_batch_size_per_gpu"] = max(
            st.session_state[key] // st.session_state["trainer_gpu_num"], 1
        )

    st.number_input(
        "Train Batch Size",
        min_value=trainer_gpu_num,
        step=trainer_gpu_num,
        help=_str_for_train_batch_size(),
        on_change=on_change,
        **kwargs,
    )

If the parameter requires validation, create a check function. For train_batch_size, we need to ensure it is divisible by trainer_gpu_num. If not, a warning should be displayed, and the parameter should be added to unfinished_fields.

Decorate the check function with the CONFIG_GENERATORS.register_check decorator:

@CONFIG_GENERATORS.register_check()
def check_train_batch_size(unfinished_fields: set, key: str):
    if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0:
        unfinished_fields.add(key)
        st.warning(_str_for_train_batch_size())

Note

The CONFIG_GENERATORS.register_check decorator automatically receives key=config_name and unfinished_fields=self.unfinished_fields as arguments. Ensure your function accepts these keyword arguments.

Step 2: Integrating New Parameters into config_manager.py

To successfully integrate new parameters into the config_manager.py file, please adhere to the following procedure:

  1. Parameter Categorization: Determine the appropriate section for the new parameter based on its functionality. The config generator page is structured into two primary modes:

    • Beginner Mode: Comprises “Essential Configs” and “Important Configs” sections.

    • Expert Mode: Includes “Model”, “Buffer”, “Explorer and Synchronizer”, and “Trainer” sections.

  2. Parameter Addition: Incorporate the new parameter into the relevant section using the self.get_configs method within the ConfigManager class.

    Example:

    class ConfigManager:
        def _expert_buffer_part(self):
            self.get_configs("total_epochs", "train_batch_size")
    
  3. YAML File Integration: Locate the appropriate position for the new parameter within the YAML file structure. This should be done in the generate_config function and its associated sub-functions.

  4. Parameter Value Assignment: Utilize st.session_state to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML.

    Example:

    class ConfigManager:
        def _gen_buffer_config(self):
            buffer_config = {
                "batch_size": st.session_state["train_batch_size"],
                # Additional configuration parameters
            }
    

By following these steps, you can ensure that new parameters are correctly added to the Config Generator page and properly integrated into the configuration system.


Check Code Style

Before submitting the code, make sure it passes the code style check. Follow these steps:

# Install code style checking tools
cd <path_to_trinity_rft>
# bash
pip install -e .[dev]
# zsh
# pip install -e .\[dev\]

# Run code style checks
pre-commit run --all-files

# Commit the code after all checks pass
git commit -am "create example workflow"