Developer Guide

This guide explains how to add new workflows to Trinity-RFT and provides related development guidelines.

Note

Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.


Creating New Workflows

Trinity-RFT allows developers to register new workflows (e.g., for multi-turn interactions or agentic scenarios). Below are the steps to create a new workflow:


Step 0: Basic Concepts

Before starting development, it’s important to understand several core concepts:

  • Task (trinity.common.workflows.Task): Represents a data structure that can be converted into a Workflow. The content of the Task varies depending on the task type:

    • Math problems: A Task contains the problem description and the golden answer.

    • Programming scenarios: A Task includes the problem description, test cases, runtime environment, and other complex information.

  • Workflow (trinity.common.workflows.Workflow): Can be understood as the running state of a Task. It defines the interaction flow between Agents and Environments, including logic similar to Rollout and Reward calculations in other frameworks. After execution, it generates a list of Experience. Trinity-RFT includes several built-in workflows:

    • MathWorkflow (trinity.common.workflows.MathWorkflow): For math scenarios. It submits problems to the LLM, parses the LLM responses, and calculates scores (rewards).

    • WebShopWorkflow (trinity.common.workflows.WebShopWorkflow): For webshop scenarios. It involves multi-turn interaction with the environment.

    • CodeWorkflow (Coming soon): For coding scenarios, executes returned code, runs tests, and calculates rewards based on test results.

  • Experience (trinity.common.experience.Experience): The output of running a Workflow. The internal data format depends on the training algorithm used. For example, for common PPO/GRPO algorithms, Experience includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.
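To make the Experience fields more concrete, the sketch below derives an action mask from a token sequence and its prompt length. SketchExperience is a hypothetical stand-in written for this guide, not the real trinity.common.experience.Experience class, and the token IDs are made up.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class SketchExperience:
    """Hypothetical stand-in for Experience; field names follow this guide."""

    tokens: List[int]  # prompt tokens followed by response tokens
    prompt_length: int  # number of leading tokens that belong to the prompt
    reward: float
    action_mask: List[int] = field(init=False)

    def __post_init__(self) -> None:
        # 0 marks prompt tokens, 1 marks tokens generated by the LLM.
        response_length = len(self.tokens) - self.prompt_length
        self.action_mask = [0] * self.prompt_length + [1] * response_length


# Three prompt tokens followed by two generated tokens (IDs are illustrative).
exp = SketchExperience(tokens=[11, 12, 13, 21, 22], prompt_length=3, reward=1.0)
print(exp.action_mask)
```

Only the positions marked 1 are the model's own actions, which is why a training algorithm would typically restrict its loss to them.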


Step 1: Prepare Task Dataset

The task dataset is loaded via the buffer.explorer_input.taskset configuration entry in your YAML config file. To handle differences in Task contents, Trinity-RFT provides a unified Task interface containing the following fields:

  • workflow (str): The registered name of your workflow class. You can specify it in buffer.explorer_input.taskset.default_workflow_type of your YAML config file.

  • reward_fn (Optional[str]): The registered name of your reward function. You can specify it in buffer.explorer_input.taskset.default_reward_fn_type. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.

  • raw_task (Dict): A record of raw data in Dict format. For highly customized workflows, you can directly use raw_task to initialize your Workflow instance without relying on the following fields.

  • format_args (trinity.common.config.FormatConfig): Parameters to facilitate the construction of Workflow instances. For example, the prompt_key and response_key can be used to get the prompt and response from raw_task. These settings come from the YAML configuration file and can be set in buffer.explorer_input.taskset.format.

  • rollout_args (trinity.common.config.GenerationConfig): Parameters that control the rollout process, such as temperature. This field also comes from the YAML configuration file and can be set in buffer.explorer_input.taskset.rollout_args.

In the math problem scenario, the Task dataset can be a jsonl file, where each line contains JSON with question and answer fields representing the problem description and standard answer, respectively. For example:

{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...
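Before pointing Trinity-RFT at such a file, it can help to check that every line parses and carries both keys. The following is a minimal sketch using only the standard library; load_jsonl_tasks is a hypothetical helper for illustration and is not part of Trinity-RFT.

```python
import json
from typing import Dict, List


def load_jsonl_tasks(lines: List[str]) -> List[Dict[str, str]]:
    """Parse JSONL lines and check that each record has the expected keys."""
    records = []
    for line_no, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        missing = {"question", "answer"} - set(record)
        if missing:
            raise ValueError(f"line {line_no}: missing keys {sorted(missing)}")
        records.append(record)
    return records


# In practice the lines would come from open(path), one JSON object per line.
sample = [
    '{"question": "1+1=", "answer": "2"}',
    '{"question": "2+2=", "answer": "4"}',
]
tasks = load_jsonl_tasks(sample)
print(tasks[0]["question"], tasks[0]["answer"])
```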

Example configuration snippet:

# some config
buffer:
  explorer_input:
    taskset:
      default_workflow_type: "math_workflow"
      path: "/PATH/TO/FILE/DIR"
      format:
        prompt_key: "question"
        response_key: "answer"
      rollout_args:
        temperature: 1.0
      # some other configs

In this example, each task object’s raw_task is a Dict with two keys (question and answer). The MathWorkflow uses the prompt_key and response_key to extract the question and answer from the raw_task, and uses the rollout_args to generate the response.
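The key lookup described above can be sketched as follows. SketchFormatArgs is a hypothetical stand-in that carries only prompt_key and response_key; the real trinity.common.config.FormatConfig may have more fields and different defaults.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class SketchFormatArgs:
    """Hypothetical stand-in for FormatConfig with only the two keys used here."""

    prompt_key: str = "question"
    response_key: str = "answer"


def extract_prompt_and_truth(
    raw_task: Dict[str, str], format_args: SketchFormatArgs
) -> Tuple[str, str]:
    """Look up the prompt and the reference answer by their configured keys."""
    return raw_task[format_args.prompt_key], raw_task[format_args.response_key]


raw_task = {"question": "1+1=", "answer": "2"}
prompt, truth = extract_prompt_and_truth(raw_task, SketchFormatArgs())
print(prompt, truth)
```

Because the keys are configuration rather than code, a dataset that uses other column names only needs a different format section in the YAML file.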


Step 2: Implement a New Workflow

The Workflow base class interface is as follows:

class Workflow(ABC):

    def __init__(
        self,
        model: ModelWrapper,
        task: Task,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        self.model = model
        self.auxiliary_models = auxiliary_models

    @abstractmethod
    def run(self) -> List[Experience]:
        """Run the workflow and return a list of Experiences."""

Initializing Your Workflow

During initialization, Workflow receives the following parameters:

  • model (trinity.common.models.model.ModelWrapper): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text response_text, full sequence token ids tokens, prompt part token length prompt_length, and a list of output token logprobs logprobs).

  • task (trinity.common.workflows.Task): A single data item from the task dataset.

  • auxiliary_models (List[openai.OpenAI]): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.

Tip

You can switch to using the OpenAI API by setting explorer.rollout_model.enable_openai_api to true in your config file and calling model.get_openai_client() to get an openai.OpenAI instance in your workflow.

Here’s an example of initializing a simple workflow using only raw_task and rollout_args. In more complex cases, you can use the format_args for further customization.

class ExampleWorkflow(Workflow):

    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
        super().__init__(model, task, auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args
        # Optional: If you want to use OpenAI API in your workflow
        # self.openai_client = self.model.get_openai_client()

Implementing the run method

The run method is the core of your workflow. It returns a list of Experience. Below is a simple implementation for a math workflow.

We first call the model to generate multiple responses using the provided question and rollout arguments. Then we calculate the reward for each response using the calculate_reward function. Finally, we construct a list of Experience with the responses and rewards and return it.

class ExampleWorkflow(Workflow):

    # the __init__ function

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences
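The exact string comparison in calculate_reward gives a reward of 0.0 whenever the model wraps the answer in extra text. One possible alternative, shown here only as a sketch and not as how MathWorkflow parses responses, is to compare the last number found in the response. The helper extract_final_number is hypothetical.

```python
import re
from typing import Optional

# Matches integers and simple decimals, with an optional leading minus sign.
_NUMBER = re.compile(r"-?\d+(?:\.\d+)?")


def extract_final_number(text: str) -> Optional[str]:
    """Return the last number that appears in the text, or None if there is none."""
    matches = _NUMBER.findall(text)
    return matches[-1] if matches else None


def calculate_reward(response: str, truth: str) -> float:
    """Reward 1.0 when the last number in the response equals the reference answer."""
    predicted = extract_final_number(response)
    if predicted is not None and predicted == truth.strip():
        return 1.0
    return 0.0


print(calculate_reward("The answer is 2.", "2"))
print(calculate_reward("I am not sure.", "2"))
```

This still compares strings, so "2.0" and "2" would not match; a real reward function would need to decide how to normalize such cases.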

Registering Your Workflow

Register your workflow using the WORKFLOWS.register_module decorator. Ensure the name does not conflict with existing workflows.

# import some packages
from trinity.common.workflows.workflow import WORKFLOWS

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    pass
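For readers unfamiliar with the registry pattern, the sketch below shows one way a name-to-class registry with a register_module decorator can work. SketchRegistry is a hypothetical illustration, not the implementation behind WORKFLOWS; in particular, raising an error on a duplicate name is a choice made for this sketch.

```python
from typing import Callable, Dict, Type


class SketchRegistry:
    """Hypothetical registry that maps a string name to a class."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, Type] = {}

    def register_module(self, module_name: str) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            # Refuse duplicate names so a new workflow cannot shadow an old one.
            if module_name in self._modules:
                raise ValueError(f"{module_name!r} is already registered in {self.name}")
            self._modules[module_name] = cls
            return cls  # the class itself is returned unchanged

        return decorator

    def get(self, module_name: str) -> Type:
        return self._modules[module_name]


SKETCH_WORKFLOWS = SketchRegistry("workflows")


@SKETCH_WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow:
    pass


print(SKETCH_WORKFLOWS.get("example_workflow").__name__)
```

The string passed to the decorator is the name that the YAML configuration later refers to, which is why it has to be unique.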

Avoid Re-initialization

For heavy workflows, re-initializing every time can incur extra computational costs. In this case, you can implement the resettable and reset methods to avoid re-initialization.

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    # some code
    # ...

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

Full Code Example

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
        super().__init__(model, task, auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

Step 3: Use Your Workflow

After implementing and registering your workflow, update the configuration file: set default_workflow_type in the buffer.explorer_input.taskset section to the newly registered workflow name.

buffer:
  # Other fields
  explorer_input:
    taskset:
      path: /path/to/taskset
      default_workflow_type: example_workflow
      # Other fields

Now you can run your workflow in Trinity-RFT using the command:

trinity run --config <your_yaml_file>

Adding New Config Entries for the Config Generator (Advanced)

Step 0: Understanding Streamlit

Before adding new parameters to the Config Generator page, it is essential to familiarize yourself with the relevant API and mechanisms of Streamlit. This project primarily utilizes various input components from Streamlit and employs st.session_state to store user-input parameters.

Step 1: Implement New Config Entries

To illustrate the process of creating a new parameter setting for the Config Generator page, we will use train_batch_size as an example.

  1. Determine the appropriate scope for the parameter. Currently, parameters are categorized into four files:

    • trinity/manager/config_registry/buffer_config_manager.py

    • trinity/manager/config_registry/explorer_config_manager.py

    • trinity/manager/config_registry/model_config_manager.py

    • trinity/manager/config_registry/trainer_config_manager.py

    In this case, train_batch_size should be placed in the buffer_config_manager.py file.

  2. Create a parameter setting function using Streamlit. The function name must follow the convention of starting with ‘set_’, and the remainder of the name becomes the config name.

  3. Decorate the parameter setting function with the CONFIG_GENERATORS.register_config decorator. This decorator requires the following information:

    • Default value of the parameter

    • Visibility condition (if applicable)

    • Additional config parameters (if needed)

Note

The CONFIG_GENERATORS.register_config decorator automatically passes key=config_name as an argument to the registered configuration function. Ensure that your function accepts this keyword argument.
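The naming convention and the key argument can be illustrated without Streamlit. The sketch below is a hypothetical, simplified decorator: a plain dict stands in for st.session_state, and register_config here is not the real CONFIG_GENERATORS.register_config, which also handles visibility and other_configs.

```python
from typing import Any, Callable, Dict

session_state: Dict[str, Any] = {}  # stand-in for st.session_state
registered: Dict[str, Callable[..., Any]] = {}


def register_config(default_value: Any) -> Callable:
    """Hypothetical decorator: derive the config name and pass it as key."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        if not func.__name__.startswith("set_"):
            raise ValueError("config functions must be named set_<config_name>")
        config_name = func.__name__[len("set_"):]
        session_state.setdefault(config_name, default_value)

        def wrapper(**kwargs: Any) -> Any:
            # The registered function always receives key=<config_name>.
            return func(key=config_name, **kwargs)

        registered[config_name] = wrapper
        return wrapper

    return decorator


@register_config(default_value=96)
def set_train_batch_size(**kwargs: Any) -> Any:
    key = kwargs.get("key")
    return key, session_state[key]


print(set_train_batch_size())
```

A function that did not accept the key keyword would fail when called through the wrapper, which is the reason for the requirement stated in the note.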

For train_batch_size, we will use the following settings:

  • Default value: 96

  • Visibility condition: lambda: st.session_state["trainer_gpu_num"] > 0

  • Additional config: {"_train_batch_size_per_gpu": 16}

Here’s the complete code for the train_batch_size parameter:

@CONFIG_GENERATORS.register_config(
    default_value=96,
    visible=lambda: st.session_state["trainer_gpu_num"] > 0,
    other_configs={"_train_batch_size_per_gpu": 16},
)
def set_train_batch_size(**kwargs):
    key = kwargs.get("key")
    trainer_gpu_num = st.session_state["trainer_gpu_num"]
    st.session_state[key] = st.session_state["_train_batch_size_per_gpu"] * trainer_gpu_num

    def on_change():
        st.session_state["_train_batch_size_per_gpu"] = max(
            st.session_state[key] // st.session_state["trainer_gpu_num"], 1
        )

    st.number_input(
        "Train Batch Size",
        min_value=trainer_gpu_num,
        step=trainer_gpu_num,
        help=_str_for_train_batch_size(),
        on_change=on_change,
        **kwargs,
    )
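The arithmetic in set_train_batch_size and its on_change callback can be followed with plain integers. The sketch below mirrors the two expressions from the code above as hypothetical helper functions; the GPU count of 4 is an assumed value chosen for illustration.

```python
def derive_batch_size(per_gpu: int, gpu_num: int) -> int:
    """Batch size written to the widget: per-GPU size times the number of GPUs."""
    return per_gpu * gpu_num


def per_gpu_after_edit(new_batch_size: int, gpu_num: int) -> int:
    """Per-GPU size stored by on_change: floor division, but never below 1."""
    return max(new_batch_size // gpu_num, 1)


gpu_num = 4  # assumed number of trainer GPUs
print(derive_batch_size(16, gpu_num))  # initial value from the per-GPU default

# If the user enters 70, the per-GPU size is rounded down and the
# derived batch size becomes a multiple of gpu_num again.
per_gpu = per_gpu_after_edit(70, gpu_num)
print(per_gpu, derive_batch_size(per_gpu, gpu_num))
```

Storing the per-GPU value and deriving the total from it keeps the batch size a multiple of the GPU count when the GPU count changes.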

If the parameter requires validation, create a check function. For train_batch_size, we need to ensure it is divisible by trainer_gpu_num. If not, a warning should be displayed, and the parameter should be added to unfinished_fields.

Decorate the check function with the CONFIG_GENERATORS.register_check decorator:

@CONFIG_GENERATORS.register_check()
def check_train_batch_size(unfinished_fields: set, key: str):
    if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0:
        unfinished_fields.add(key)
        st.warning(_str_for_train_batch_size())

Note

Check functions registered with the CONFIG_GENERATORS.register_check decorator are automatically called with key=config_name and unfinished_fields=self.unfinished_fields. Ensure your function accepts these keyword arguments.
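The effect of the check on unfinished_fields can be tried outside Streamlit. In this sketch a plain dict stands in for st.session_state and the warning is omitted; sketch_check_train_batch_size is a hypothetical variant that takes the state as an explicit argument.

```python
from typing import Any, Dict, Set


def sketch_check_train_batch_size(
    state: Dict[str, Any], unfinished_fields: Set[str], key: str
) -> bool:
    """Record the key as unfinished when the value is not divisible by the GPU count."""
    if state[key] % state["trainer_gpu_num"] != 0:
        unfinished_fields.add(key)
        return False
    return True


unfinished: Set[str] = set()
ok = sketch_check_train_batch_size(
    {"train_batch_size": 96, "trainer_gpu_num": 6}, unfinished, "train_batch_size"
)
bad = sketch_check_train_batch_size(
    {"train_batch_size": 100, "trainer_gpu_num": 6}, unfinished, "train_batch_size"
)
print(ok, bad, unfinished)
```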

Step 2: Integrating New Parameters into config_manager.py

To successfully integrate new parameters into the config_manager.py file, please adhere to the following procedure:

  1. Parameter Categorization: Determine the appropriate section for the new parameter based on its functionality. The config generator page is structured into two primary modes:

    • Beginner Mode: Comprises “Essential Configs” and “Important Configs” sections.

    • Expert Mode: Includes “Model”, “Buffer”, “Explorer and Synchronizer”, and “Trainer” sections.

  2. Parameter Addition: Incorporate the new parameter into the relevant section using the self.get_configs method within the ConfigManager class.

    Example:

    class ConfigManager:
        def _expert_buffer_part(self):
            self.get_configs("total_epochs", "train_batch_size")
    
  3. YAML File Integration: Locate the appropriate position for the new parameter within the YAML file structure. This should be done in the generate_config function and its associated sub-functions.

  4. Parameter Value Assignment: Utilize st.session_state to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML.

    Example:

    class ConfigManager:
        def _gen_buffer_config(self):
            buffer_config = {
                "batch_size": st.session_state["train_batch_size"],
                # Additional configuration parameters
            }
    

Following these steps adds the new parameter to the Config Generator page and integrates it into the generated configuration.


Check Code Style

Before submitting the code, make sure it passes the code style check. Follow these steps:

# Install code style checking tools
cd <path_to_trinity_rft>
# bash
pip install -e .[dev]
# zsh
# pip install -e .\[dev\]

# Run code style checks
pre-commit run --all-files

# Commit the code after all checks pass
git commit -am "create example workflow"