Developer Guide

This guide introduces how to develop new modules in Trinity-RFT and provides relevant development guidelines.

In Trinity-RFT, we decompose the RL pipeline into three main modules (Explorer, Trainer and Buffer) to facilitate customization and extension. Below is a table summarizing the modules and components that developers with different targets need to focus on.

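| Development Target                                  | Core Module | Key Component |
|-----------------------------------------------------|-------------|---------------|
| Apply existing RL algorithms to new environments.   | Explorer    | Workflow      |
| Design new RL algorithms.                           | Trainer     | Algorithm     |
| Enhance the RL process from the data perspective.   | Buffer      | Operator      |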

Note

Trinity-RFT is under active development, and the following interfaces may change. Please refer to the latest code when using this guide.


Workflows (For RL Environment Developers)

In Trinity-RFT, workflows are the core components that define the interaction between Agents and Environments. A valid workflow uses the model being trained to complete the given task and obtains feedback (a reward) from the environment. Below are the steps to create a new workflow:


Step 0: Basic Concepts

Before starting development, it’s important to understand several core concepts:

  • Task (trinity.common.workflows.Task): Represents a data structure that can be converted into a Workflow. The content of the Task varies depending on the task type:

    • Math problems: A Task contains the problem description and the golden answer.

    • Programming scenarios: A Task includes the problem description, test cases, runtime environment, and other complex information.

  • Workflow (trinity.common.workflows.Workflow): Describes how a Task is executed. It defines the interaction flow between Agents and Environments, including logic similar to Rollout and Reward calculations in other frameworks. After execution, it generates a list of Experience. Trinity-RFT includes several built-in workflows:

    • MathWorkflow (trinity.common.workflows.MathWorkflow): For math scenarios; submits problems to the LLM, parses the LLM's responses, and calculates scores (rewards).

    • WebShopWorkflow (trinity.common.workflows.WebShopWorkflow): For webshop scenarios; involves multi-turn interaction with the environment.

    • CodeWorkflow (Coming soon): For coding scenarios, executes returned code, runs tests, and calculates rewards based on test results.

  • Experience (trinity.common.experience.Experience): The output of running a Workflow. The internal data format depends on the training algorithm used. For example, for common PPO/GRPO algorithms, Experience includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.


Step 1: Prepare Task Dataset

The task dataset is loaded via the buffer.explorer_input.taskset configuration entry in your YAML config file. To handle differences in Task contents, Trinity-RFT provides a unified Task interface containing the following fields.

  • workflow (str): The registered name of your workflow class. You can specify it in buffer.explorer_input.taskset.default_workflow_type of your YAML config file.

  • reward_fn (Optional[str]): The registered name of your reward function. You can specify it in buffer.explorer_input.taskset.default_reward_fn_type. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.

  • raw_task (Dict): A record of raw data in Dict format. For highly customized workflows, you can use raw_task directly to initialize your Workflow instance without relying on the following fields.

  • format_args (trinity.common.config.FormatConfig): Parameters to facilitate the construction of Workflow instances. For example, prompt_key and response_key can be used to get the prompt and response from raw_task. These settings come from the YAML configuration file and can be set in buffer.explorer_input.taskset.format.

  • rollout_args (trinity.common.config.GenerationConfig): Parameters that control the rollout process, such as temperature. This field also comes from the YAML configuration file and can be set in buffer.explorer_input.taskset.rollout_args.

  • workflow_args (Dict): A dictionary of parameters to facilitate the construction of Workflow instances. Being a free-form dictionary, it is more flexible than format_args and rollout_args. This field also comes from the YAML configuration file and can be set in buffer.explorer_input.taskset.workflow_args. Normally, you do not need to set this field.

Tip

workflow, workflow_args and raw_task provide different levels of customization.

  • workflow provides the global settings for all tasks that use the same workflow. (Global Level)

  • workflow_args can be set for each task dataset, allowing different task datasets that use the same workflow to behave differently. (Dataset Level)

  • raw_task provides the ability to customize the behavior of each task, which is the most flexible. (Data Sample Level)

In the math problem scenario, the Task dataset can be a jsonl file, where each line contains JSON with question and answer fields representing the problem description and standard answer, respectively. For example:

{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...

Example configuration snippet:

# some config
buffer:
  explorer_input:
    taskset:
      default_workflow_type: "math_workflow"
      path: "/PATH/TO/FILE/DIR"
      format:
        prompt_key: "question"
        response_key: "answer"
      rollout_args:
        temperature: 1.0
      # some other configs

In this example, each task object’s raw_task is a Dict with two keys (question and answer). MathWorkflow uses prompt_key and response_key to extract the question and answer from raw_task, and uses rollout_args to generate the response.
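The following minimal sketch, assuming a simplified stand-in for FormatConfig, shows how the two keys under format select fields from each raw_task parsed from the jsonl file. FormatArgs, load_raw_tasks and extract_pair are hypothetical names, not Trinity-RFT APIs.

```python
import json
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class FormatArgs:
    # Hypothetical stand-in for the prompt_key/response_key part of FormatConfig.
    prompt_key: str = "question"
    response_key: str = "answer"


def load_raw_tasks(jsonl_text: str) -> List[Dict]:
    """Parse JSONL text into one raw_task dict per non-empty line."""
    return [json.loads(line) for line in jsonl_text.splitlines() if line.strip()]


def extract_pair(raw_task: Dict, fmt: FormatArgs) -> Tuple[str, str]:
    """Pick the prompt and the reference answer out of a raw_task using the configured keys."""
    return raw_task[fmt.prompt_key], raw_task[fmt.response_key]


data = '{"question": "1+1=", "answer": "2"}\n{"question": "2+2=", "answer": "4"}\n'
fmt = FormatArgs(prompt_key="question", response_key="answer")
pairs = [extract_pair(task, fmt) for task in load_raw_tasks(data)]
print(pairs)  # [('1+1=', '2'), ('2+2=', '4')]
```

Changing prompt_key/response_key in the YAML is then enough to reuse the same workflow on a dataset whose fields have different names.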


Step 2: Implement a New Workflow

The Workflow base class interface is as follows:

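from abc import ABC, abstractmethod
from typing import List, Optional

import openai

# Task, ModelWrapper and Experience are the Trinity-RFT types described in this guide.


class Workflow(ABC):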

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        self.task = task
        self.model = model
        self.auxiliary_models = auxiliary_models

    @abstractmethod
    def run(self) -> List[Experience]:
        """Run the workflow and return a list of Experiences."""

Initialize Your Workflow

During initialization, Workflow receives the following parameters:

  • task (trinity.common.workflows.Task): A single data item from the task dataset.

  • model (trinity.common.models.model.ModelWrapper): The model being trained. It provides an interface similar to OpenAI's: it receives a list of conversation messages and returns the content generated by the LLM, including the reply text (response_text), the token IDs of the full sequence (tokens), the length of the prompt part (prompt_length), and the log probabilities of the output tokens (logprobs).

  • auxiliary_models (List[openai.OpenAI]): A list of auxiliary models that are not involved in training. All of them are provided via OpenAI-compatible APIs.

Tip

You can switch to using the OpenAI API by setting explorer.rollout_model.enable_openai_api to true in your config file and calling model.get_openai_client() in your workflow to get an openai.OpenAI instance. The model name to pass when calling the OpenAI API can be obtained via openai_client.models.list().data[0].id or openai_client.model_path.
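A minimal sketch of the OpenAI route, assuming client is the openai.OpenAI instance returned by model.get_openai_client(). The helper chat_via_openai is a hypothetical name; it only relies on the standard openai calls models.list() and chat.completions.create(). For an offline run, a fake object that mimics the response shape stands in for the real client.

```python
from types import SimpleNamespace
from typing import List


def chat_via_openai(client, question: str, n: int = 1, temperature: float = 1.0) -> List[str]:
    """Hypothetical helper: query the rollout model through an OpenAI-compatible client."""
    # The model name is read from the server's model list, as described in the tip.
    model_name = client.models.list().data[0].id
    completion = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": f"Question:\n{question}"}],
        n=n,
        temperature=temperature,
    )
    # One choice is returned per requested sample.
    return [choice.message.content for choice in completion.choices]


class _FakeCompletions:
    """Offline stand-in that mimics the shape of an openai chat completion response."""

    def create(self, model, messages, n=1, temperature=1.0):
        choices = [
            SimpleNamespace(message=SimpleNamespace(content=f"{model}-answer-{i}"))
            for i in range(n)
        ]
        return SimpleNamespace(choices=choices)


fake_client = SimpleNamespace(
    models=SimpleNamespace(list=lambda: SimpleNamespace(data=[SimpleNamespace(id="policy-model")])),
    chat=SimpleNamespace(completions=_FakeCompletions()),
)
texts = chat_via_openai(fake_client, "1+1=", n=2)
print(texts)  # ['policy-model-answer-0', 'policy-model-answer-1']
```

Inside a workflow you would pass the client obtained from self.model.get_openai_client() instead of fake_client.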

Here’s an example of initializing a simple workflow using only raw_task and rollout_args. In more complex cases, you can use the format_args for further customization.

class ExampleWorkflow(Workflow):

    def __init__(self, *, task: Task, model: ModelWrapper, auxiliary_models: Optional[List] = None):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args
        # Optional: If you want to use OpenAI API in your workflow
        # self.openai_client = self.model.get_openai_client()

Implementing the run method

The run method is the core of your workflow. It returns a list of Experience. Below is a simple implementation for a math workflow.

We first call the model to generate multiple responses using the provided question and rollout arguments. Then we calculate the reward for each response using the calculate_reward function. Finally, we construct a list of Experience with the responses and rewards and return it.

class ExampleWorkflow(Workflow):

    # the __init__ function

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

Registering Your Workflow

Register your workflow using the WORKFLOWS.register_module decorator. Ensure the name does not conflict with existing workflows.

# import some packages
from trinity.common.workflows.workflow import WORKFLOWS

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    pass

For workflows that you plan to contribute to the Trinity-RFT project, place the above code in the trinity/common/workflows folder, e.g., trinity/common/workflows/example_workflow.py, and add the following lines to trinity/common/workflows/__init__.py:

# existing import lines
from .example_workflow import ExampleWorkflow

__all__ = [
    # existing __all__ lines
    "ExampleWorkflow",
]
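Conceptually, WORKFLOWS maps registered names to classes, and the name in your config is looked up in it. The following illustrative registry (not Trinity-RFT's implementation; Registry and DEMO_WORKFLOWS are hypothetical names) shows the idea and why a duplicate name is a problem: this sketch simply rejects it.

```python
from typing import Callable, Dict, Type


class Registry:
    """Minimal name-to-class registry, for illustration only."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, Type] = {}

    def register_module(self, module_name: str) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            # Refuse silent overrides so that name clashes surface immediately.
            if module_name in self._modules:
                raise KeyError(f"{module_name!r} is already registered in {self.name}")
            self._modules[module_name] = cls
            return cls  # the class itself is returned unchanged
        return decorator

    def get(self, module_name: str) -> Type:
        return self._modules[module_name]


DEMO_WORKFLOWS = Registry("workflows")


@DEMO_WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow:
    pass


# A config value such as default_workflow_type: example_workflow is resolved by name.
workflow_cls = DEMO_WORKFLOWS.get("example_workflow")

# Registering a second class under the same name is rejected.
try:
    DEMO_WORKFLOWS.register_module("example_workflow")(type("AnotherWorkflow", (), {}))
    duplicate_rejected = False
except KeyError:
    duplicate_rejected = True

print(workflow_cls.__name__, duplicate_rejected)  # ExampleWorkflow True
```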

Avoid Re-initialization

For heavy workflows, re-initializing the workflow for every task can incur extra computational costs. In this case, you can implement the resettable and reset methods to avoid re-initialization.

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    # some code
    # ...

    def resettable(self):
        return True

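    def reset(self, task: Task):
        # refresh every per-task attribute that __init__ sets
        self.task = task
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args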
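How Trinity-RFT's runner decides when to call reset() is not spelled out in this guide. As a hypothetical sketch of the pattern, a runner could keep one instance and reset it for each new task whenever resettable() returns True, so the expensive constructor runs only once. The Task stand-in, CachedExampleWorkflow and acquire_workflow are illustrative names.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class Task:
    # Hypothetical stand-in for trinity.common.workflows.Task; only raw_task is modelled.
    raw_task: Dict = field(default_factory=dict)


class CachedExampleWorkflow:
    init_calls = 0  # counts constructor calls to show that reuse skips re-initialization

    def __init__(self, task: Task) -> None:
        CachedExampleWorkflow.init_calls += 1  # imagine expensive setup here
        self.reset(task)

    def resettable(self) -> bool:
        return True

    def reset(self, task: Task) -> None:
        self.task = task
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")


def acquire_workflow(cached: Optional[CachedExampleWorkflow], task: Task) -> CachedExampleWorkflow:
    """Hypothetical runner logic: reuse the cached instance when it is resettable."""
    if cached is not None and cached.resettable():
        cached.reset(task)
        return cached
    return CachedExampleWorkflow(task)


workflow = None
for raw in [{"question": "1+1=", "answer": "2"}, {"question": "2+2=", "answer": "4"}]:
    workflow = acquire_workflow(workflow, Task(raw_task=raw))
print(CachedExampleWorkflow.init_calls, workflow.question)  # 1 2+2=
```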

Full Code Example

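from typing import List, Optional

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow


@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

    def __init__(self, *, task: Task, model: ModelWrapper, auxiliary_models: Optional[List] = None):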
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

    def resettable(self):
        return True

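    def reset(self, task: Task):
        # refresh every per-task attribute that __init__ sets
        self.task = task
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args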
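To check the reward logic of a workflow like this without launching Trinity-RFT, you can replay the body of run() against a fake model. The sketch below is hypothetical: FakeResponse, FakeModel and dry_run_rewards are illustrative names, and the fake only mimics the response fields used above (response_text, tokens, prompt_length, logprobs), not the real ModelWrapper.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeResponse:
    # Only the fields that ExampleWorkflow.run() reads from model.chat() results.
    response_text: str
    tokens: List[int]
    prompt_length: int
    logprobs: List[float]


class FakeModel:
    """Hypothetical stand-in for ModelWrapper that replays canned answers instead of sampling."""

    def __init__(self, answers: List[str]) -> None:
        self.answers = answers

    def chat(self, messages: List[Dict[str, str]], n: int = 1, temperature: float = 1.0) -> List[FakeResponse]:
        return [
            FakeResponse(response_text=text, tokens=[101, 102, 7], prompt_length=2, logprobs=[-0.1])
            for text in self.answers[:n]
        ]


def calculate_reward(response: str, truth: str) -> float:
    # Same exact-match rule as ExampleWorkflow.calculate_reward.
    return 1.0 if response == truth else 0.0


def dry_run_rewards(model: FakeModel, question: str, answer: str, n: int) -> List[float]:
    """Replays the body of ExampleWorkflow.run() and returns only the rewards."""
    responses = model.chat([{"role": "user", "content": f"Question:\n{question}"}], n=n, temperature=1.0)
    return [calculate_reward(r.response_text, answer) for r in responses]


rewards = dry_run_rewards(FakeModel(["2", "3", " 2"]), question="1+1=", answer="2", n=3)
print(rewards)  # [1.0, 0.0, 0.0]
```

The third response " 2" scores 0.0 only because of its leading space. Exact matching is strict, so for real tasks you may want to normalize the response (for example, strip whitespace or parse the final answer) before comparing.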

Step 3: Use Your Workflow

After implementing and registering your workflow, update the configuration file: set default_workflow_type in the buffer.explorer_input.taskset section to the name under which you registered the workflow.

buffer:
  # Other fields
  explorer_input:
    taskset:
      path: /path/to/taskset
      default_workflow_type: example_workflow
      # Other fields

Now you can run your workflow in Trinity-RFT using the command:

trinity run --config <your_yaml_file>

Algorithms (For RL Algorithm Developers)

Trinity-RFT provides a standardized process for implementing new algorithms.

Step 0: Basic Concepts of Algorithm Module

In Trinity-RFT, the algorithm module is primarily responsible for extracting experience data from the Replay Buffer during the RL process and calculating the loss to update models based on this data. To avoid implementing a new Trainer class each time a new algorithm is added, we have decomposed the representative PPO algorithm process into multiple sub-modules to adapt to various algorithms.

  • Sample Strategy (trinity.algorithm.SampleStrategy): Responsible for sampling experience data from the buffer module. By customizing this module, you can implement functionalities like filtering experience data or mixed sampling from multiple data sources.

  • Advantage Fn (trinity.algorithm.AdvantageFn): Responsible for calculating the advantages and returns of experience data.

  • Policy Loss Fn (trinity.algorithm.PolicyLossFn): Responsible for calculating the core training loss of the policy network.

  • KL Fn (trinity.algorithm.KLFn): Responsible for calculating the KL divergence, which existing RL algorithms generally use in two places: the reward penalty and the actor loss.

  • Entropy Loss Fn (trinity.algorithm.EntropyLossFn): Responsible for calculating the entropy loss of the policy network.

We provide several implementations of the above modules in trinity/algorithm.


Step 1: Implement Algorithm Components

Trinity-RFT allows developers to customize all the above modules. Developers only need to implement specific modules according to the requirements of their new algorithm. This section will provide a simple introduction using the OPMD algorithm as an example.

The main difference between OPMD and PPO algorithms lies in the calculation of Advantage and Policy Loss. OPMD relies on a group-based advantage calculation and does not use the Critic model. To implement OPMD, developers need to implement advantage calculation in AdvantageFn and policy loss calculation in PolicyLossFn.


Step 1.1: Implement AdvantageFn

The trinity.algorithm.AdvantageFn interface includes three methods:

  • __call__: The main entrance for advantage calculation. It receives a list of experiences (trinity.common.experience.Experience) and returns a list of experiences with calculated advantages and returns, along with a metrics dictionary for logging.

  • default_args: A class method that returns the default initialization parameters in dictionary form. It will be used by default when users don’t specify initialization parameters in the configuration file.

  • compute_in_trainer: This class method indicates whether to compute advantages in the Trainer. If it returns False, the AdvantageFn will be called in the experience data processing pipeline.

For convenience, Trinity-RFT provides an abstract class, trinity.algorithm.advantage_fn.GroupAdvantage, which implements the __call__ method for group-based advantage calculation. With it, you only need to define how to group the experiences and how to calculate advantages within each group, using the following two methods:

  • group_experiences: Groups the experiences generated in one step into multiple sub-groups.

  • calculate_group_advantage: This method calculates the advantage for each group of experiences.
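To make the division of labour concrete, here is a hypothetical, framework-free sketch of how a GroupAdvantage-style __call__ could combine the two methods. Experiences are reduced to a task id and a float reward (Exp and MeanBaselineGroupAdvantage are illustrative names standing in for Experience and a real subclass), and per-group metrics are simply merged into one dict; Trinity-RFT's own implementation may aggregate them differently.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Exp:
    # Hypothetical minimal experience: the task it came from and its scalar reward.
    task_id: str
    reward: float
    advantage: float = 0.0


class MeanBaselineGroupAdvantage:
    """Illustrative re-creation of the GroupAdvantage pattern (not Trinity-RFT's code)."""

    def group_experiences(self, exps: List[Exp]) -> Dict[str, List[Exp]]:
        groups: Dict[str, List[Exp]] = defaultdict(list)
        for exp in exps:
            groups[exp.task_id].append(exp)
        return groups

    def calculate_group_advantage(self, group_id: str, exps: List[Exp]) -> Tuple[List[Exp], Dict]:
        baseline = sum(e.reward for e in exps) / len(exps)
        for e in exps:
            e.advantage = e.reward - baseline
        return exps, {f"{group_id}/baseline": baseline}

    def __call__(self, exps: List[Exp]) -> Tuple[List[Exp], Dict]:
        # Group first, then compute advantages group by group and merge the metrics.
        out: List[Exp] = []
        metrics: Dict = {}
        for group_id, group in self.group_experiences(exps).items():
            group_exps, group_metrics = self.calculate_group_advantage(group_id, group)
            out.extend(group_exps)
            metrics.update(group_metrics)
        return out, metrics


exps = [Exp("t1", 1.0), Exp("t1", 0.0), Exp("t2", 1.0), Exp("t2", 1.0)]
out, metrics = MeanBaselineGroupAdvantage()(exps)
print([e.advantage for e in out], metrics)
# [0.5, -0.5, 0.0, 0.0] {'t1/baseline': 0.5, 't2/baseline': 1.0}
```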

Here’s an implementation example for the OPMD algorithm’s advantage function:

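from typing import Dict, List, Tuple

import torch

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, GroupAdvantage
from trinity.common.experience import Experience, group_by  # group_by location may differ across versions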

@ADVANTAGE_FN.register_module("opmd")
class OPMDGroupAdvantage(GroupAdvantage):
    """OPMD Group Advantage computation"""

    def __init__(self, opmd_baseline: str = "mean", tau: float = 1.0, **kwargs) -> None:
        super().__init__(**kwargs)
        self.opmd_baseline = opmd_baseline
        self.tau = tau

    def group_experiences(self, exps):
        return group_by(exps, id_type="task")

    def calculate_group_advantage(
        self, group_id: str, exps: List[Experience]
    ) -> Tuple[List[Experience], Dict]:
        with torch.no_grad():
            if len(exps) == 1:
                group_baseline = torch.tensor(0.0)
            else:
                group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)
                if self.opmd_baseline == "mean":
                    group_baseline = torch.mean(group_rewards)
                else:
                    group_baseline = self.tau * (
                        torch.logsumexp(group_rewards / self.tau, dim=-1)
                        - torch.log(torch.tensor(len(exps)))
                    )
            for exp in exps:
                score = exp.reward - group_baseline
                exp.advantages = score * exp.action_mask
                exp.returns = exp.advantages.clone()
            metrics = {
                "group_baseline": group_baseline.item(),
            }
        return exps, metrics

    @classmethod
    def default_args(cls) -> dict:
        return {"opmd_baseline": "mean", "tau": 1.0}

After implementation, you need to register this module through trinity.algorithm.ADVANTAGE_FN. Once registered, the module can be configured in the configuration file using the registered name.
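The two baselines in OPMDGroupAdvantage differ in how much weight the best responses get. The sketch below recomputes them with plain Python floats (opmd_baseline is a hypothetical helper, not part of Trinity-RFT). "mean" is the arithmetic mean, while any other value, such as the "logavgexp" used later in this guide, gives tau * (logsumexp(r / tau) - log n), a log-mean-exp that lies between the group mean and the group maximum: it approaches the maximum as tau shrinks and the mean as tau grows.

```python
import math
from typing import List


def opmd_baseline(rewards: List[float], kind: str = "mean", tau: float = 1.0) -> float:
    """Group baseline following OPMDGroupAdvantage, written with plain floats."""
    if len(rewards) == 1:
        return 0.0  # a single-sample group gets a zero baseline
    if kind == "mean":
        return sum(rewards) / len(rewards)
    # Any other kind (e.g. "logavgexp"): tau * (logsumexp(r / tau) - log n).
    scaled = [r / tau for r in rewards]
    peak = max(scaled)  # subtract the max before exponentiating for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(s - peak) for s in scaled))
    return tau * (log_sum_exp - math.log(len(rewards)))


rewards = [1.0, 0.0, 0.0, 1.0]
print(
    round(opmd_baseline(rewards), 4),
    round(opmd_baseline(rewards, "logavgexp", tau=1.0), 4),
    round(opmd_baseline(rewards, "logavgexp", tau=0.1), 4),
)  # 0.5 0.6201 0.9307
advantages = [r - opmd_baseline(rewards) for r in rewards]  # [0.5, -0.5, -0.5, 0.5]
```

In the real implementation each scalar score is multiplied by exp.action_mask, so every generated token of a response receives the same advantage.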

Step 1.2: Implement PolicyLossFn

Developers need to implement the trinity.algorithm.PolicyLossFn interface, which is similar to AdvantageFn and includes two methods:

  • __call__: Calculates the loss based on input parameters. Unlike AdvantageFn, the input parameters here are all torch.Tensor. This interface automatically scans the parameter list of the __call__ method and converts it to the corresponding fields in the experience data. Therefore, please write all tensor names needed for loss calculation directly in the parameter list, rather than selecting parameters from kwargs.

  • default_args: Returns default initialization parameters in dictionary form, which will be used by default when users don’t specify initialization parameters in the configuration file.

Similarly, after implementation, you need to register this module through trinity.algorithm.POLICY_LOSS_FN.

Here’s an implementation example for the OPMD algorithm’s Policy Loss Fn. Since OPMD’s Policy Loss only requires logprob, action_mask, and advantages, only these three items are specified in the parameter list of the __call__ method:

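from typing import Dict, Tuple

import torch

from trinity.algorithm import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean  # location may differ across versions


@POLICY_LOSS_FN.register_module("opmd")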
class OPMDPolicyLossFn(PolicyLossFn):
    def __init__(self, tau: float = 1.0) -> None:
        self.tau = tau

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        pg_losses = -advantages * logprob
        opmd_loss = masked_mean(pg_losses, action_mask)
        opmd_loss = opmd_loss / (1.0 + self.tau)  # for regularization (w.r.t. current pi_theta)
        return opmd_loss, {"opmd_loss": opmd_loss.detach().item()}

    @classmethod
    def default_args(cls) -> Dict:
        return {"tau": 1.0}
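The automatic parameter scanning described for PolicyLossFn.__call__ can be pictured with inspect.signature. The sketch below is hypothetical (call_with_matching_fields and ToyOPMDLoss are illustrative names) and uses plain lists instead of tensors: only fields named explicitly in the signature are passed, which is why tensors must not be picked from **kwargs. The toy loss follows the OPMD formula above with a simple masked mean (sum over masked tokens divided by the mask sum); Trinity-RFT's masked_mean may differ in details.

```python
import inspect
from typing import Any, Callable, Dict, List


def call_with_matching_fields(loss_fn: Callable, batch: Dict[str, Any]):
    """Hypothetical dispatcher: pass loss_fn only the batch fields named in its signature."""
    wanted = [
        name
        for name, param in inspect.signature(loss_fn).parameters.items()
        # *args and **kwargs are skipped, so nothing is routed through them.
        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    ]
    missing = [name for name in wanted if name not in batch]
    if missing:
        raise KeyError(f"batch has no field(s) {missing} required by the loss function")
    return loss_fn(**{name: batch[name] for name in wanted})


class ToyOPMDLoss:
    """Plain-Python version of the OPMD loss arithmetic, using lists instead of tensors."""

    def __init__(self, tau: float = 1.0) -> None:
        self.tau = tau

    def __call__(self, logprob: List[float], action_mask: List[float], advantages: List[float], **kwargs):
        # masked mean of -advantage * logprob over generated tokens, then scaled by 1 / (1 + tau)
        losses = [-a * lp * m for lp, m, a in zip(logprob, action_mask, advantages)]
        return sum(losses) / sum(action_mask) / (1.0 + self.tau)


batch = {
    "logprob": [-1.0, -2.0],
    "action_mask": [1.0, 0.0],
    "advantages": [0.5, 0.5],
    "old_logprob": [-1.1, -2.1],  # present in the batch but not requested, so never passed
}
print(call_with_matching_fields(ToyOPMDLoss(tau=1.0), batch))  # 0.25
```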

Step 2: Register Your Algorithm

The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect.

To simplify configuration, Trinity-RFT provides trinity.algorithm.AlgorithmType to describe a complete algorithm, which is registered in ALGORITHM_TYPE to enable one-click configuration.

The AlgorithmType class includes the following attributes and methods:

  • use_critic: Whether to use the Critic model

  • use_reference: Whether to use the Reference model

  • compute_advantage_in_trainer: Whether to calculate Advantages in Trainer; if False, the AdvantageFn call in trainer will be skipped

  • can_balance_batch: Whether the algorithm allows automatic balancing when splitting a batch into microbatches (which permutes the order of samples)

  • schema: The format of experience data corresponding to the algorithm

  • default_config: Gets the default configuration of the algorithm, which will override attributes with the same name in ALGORITHM_TYPE

Similarly, after implementation, you need to register this module through ALGORITHM_TYPE.

Below is the implementation for the OPMD algorithm. Since the OPMD algorithm doesn’t need to use the Critic model, use_critic is set to False. The dictionary returned by the default_config method indicates that OPMD will use the opmd type AdvantageFn and PolicyLossFn implemented in Step 1, will not apply KL Penalty on rewards, but will add a k2 type KL loss when calculating the final loss.

@ALGORITHM_TYPE.register_module("opmd")
class OPMDAlgorithm(AlgorithmType):
    """OPMD algorithm."""

    use_critic: bool = False
    use_reference: bool = True
    compute_advantage_in_trainer: bool = False
    can_balance_batch: bool = True
    schema: type = ExperienceModel

    @classmethod
    def default_config(cls) -> Dict:
        return {
            "repeat_times": 2,
            "advantage_fn": "opmd",
            "sample_strategy": "warmup",
            "policy_loss_fn": "opmd",
            "kl_penalty_fn": "none",
            "kl_loss_fn": "k2",
            "entropy_loss_fn": "default",
        }

Step 3: Use Your Algorithm

After completing all the above steps, you can use the newly registered algorithm through a YAML configuration file.

For default configurations, you just need to add the following content to your config.yaml file:

# some other configs
algorithm:
  algorithm_type: "opmd"
# some other configs

If you need to modify certain parameters, you can simply add the corresponding parameters within the algorithm section. For example, if you need to modify repeat_times and the initialization parameters of AdvantageFn and PolicyLossFn, the modified config.yaml file would be as follows:

# some other configs
algorithm:
  algorithm_type: "opmd"
  repeat_times: 8
  advantage_fn_args:
    opmd_baseline: "logavgexp"
    tau: 0.99
  policy_loss_fn_args:
    tau: 0.99
# some other configs
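The interplay between default_config(), default_args() and the YAML overrides can be sketched as a dictionary merge. This is an illustrative assumption, not Trinity-RFT's code: resolve_algorithm_section is a hypothetical name, the defaults are copied from the OPMD examples above, and the per-key fallback of unspecified *_args entries to default_args() is assumed.

```python
from typing import Any, Dict

# Hypothetical defaults copied from the examples above: OPMDAlgorithm.default_config()
# plus the default_args() of the OPMD AdvantageFn and PolicyLossFn.
OPMD_DEFAULTS: Dict[str, Any] = {
    "repeat_times": 2,
    "advantage_fn": "opmd",
    "sample_strategy": "warmup",
    "policy_loss_fn": "opmd",
    "kl_penalty_fn": "none",
    "kl_loss_fn": "k2",
    "entropy_loss_fn": "default",
    "advantage_fn_args": {"opmd_baseline": "mean", "tau": 1.0},
    "policy_loss_fn_args": {"tau": 1.0},
}


def resolve_algorithm_section(user: Dict[str, Any]) -> Dict[str, Any]:
    """Illustrative merge: values from the YAML algorithm section override the defaults."""
    resolved = dict(OPMD_DEFAULTS)
    for key, value in user.items():
        if key.endswith("_args") and isinstance(value, dict):
            # Assumption: constructor arguments are merged key by key with default_args().
            resolved[key] = {**OPMD_DEFAULTS.get(key, {}), **value}
        else:
            resolved[key] = value
    return resolved


user_section = {
    "algorithm_type": "opmd",
    "repeat_times": 8,
    "advantage_fn_args": {"opmd_baseline": "logavgexp", "tau": 0.99},
    "policy_loss_fn_args": {"tau": 0.99},
}
resolved = resolve_algorithm_section(user_section)
print(resolved["repeat_times"], resolved["kl_loss_fn"], resolved["advantage_fn_args"])
# 8 k2 {'opmd_baseline': 'logavgexp', 'tau': 0.99}
```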

Operators (For Data Developers)

Step 0: Basic Concepts of Operator Module

In Trinity-RFT, the operator module is responsible for processing experience data in the buffer module. It natively supports the data processing operators of Data-Juicer and also allows developers to implement their own operators. By customizing operators, developers can implement various data processing functionalities, such as data augmentation, filtering, and transformation. You can even implement advantage/return calculation as an operator, as mentioned in the Algorithms section.

  • DataJuicerOperator (trinity.data.operators.DataJuicerOperator): The operator that wraps the data processing operators from Data-Juicer. It provides a simple interface for developers to list the Data-Juicer operators they want to use. The full list of Data-Juicer operators can be found in the Data-Juicer documentation.

  • ExperienceOperator (trinity.buffer.operators.ExperienceOperator): The base class for all operators used in experience data processing. It defines the interface and common functionalities that all operators should have. Each operator processes a batch of experience data and returns the processed data with metrics for logging.

  • ExperiencePipeline (trinity.data.pipelines.ExperiencePipeline): The experience data processing pipeline that manages a sequence of operators. It takes raw experiences from the Explorer, passes them through each operator in the pipeline, and writes the final processed experiences into the input buffer of the Trainer.

Note

In addition to ExperiencePipeline, Trinity-RFT also provides TaskPipeline for task data processing. In the current version, TaskPipeline only supports Data-Juicer operators. Please see this section for details.


Developers can implement and use their own operators by following the steps below.

Step 1: Implement Operator

The ExperienceOperator interface includes only one process method. The ExperiencePipeline will call this method with a list of Experience generated by the Explorer in one explore step. The process method should return a tuple containing the processed list of Experience and a dictionary of metrics for logging.

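from abc import ABC, abstractmethod
from typing import Dict, List, Tuple

from trinity.common.experience import Experience


class ExperienceOperator(ABC):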

    @abstractmethod
    def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        """Process a list of experiences and return a transformed list.

        Args:
            exps (List[Experience]): List of experiences to process, which contains
                all experiences generated by the Explorer in one explore step.
        Returns:
            Tuple[List[Experience], Dict]: A tuple containing the processed list of experiences and a dictionary of metrics.
        """

Here is an implementation of a simple operator that filters out experiences with rewards below a certain threshold:

from typing import Dict, List, Tuple

from trinity.buffer.operators import EXPERIENCE_OPERATORS, ExperienceOperator
from trinity.common.experience import Experience


@EXPERIENCE_OPERATORS.register_module("reward_filter")
class RewardFilter(ExperienceOperator):

    def __init__(self, threshold: float = 0.0) -> None:
        self.threshold = threshold

    def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        filtered_exps = [exp for exp in exps if exp.reward >= self.threshold]
        metrics = {"filtered_count": len(exps) - len(filtered_exps)}
        return filtered_exps, metrics

After implementation, you need to register this module through trinity.buffer.operators.EXPERIENCE_OPERATORS. Once registered, the module can be configured in the configuration file using the registered name.
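The way ExperiencePipeline chains operators can be sketched framework-free. In this hypothetical version, Exp stands in for Experience, RewardFilter repeats the logic above without the base class and registry, and run_operators feeds each operator's output to the next operator while merging their metrics; the real pipeline also reads from and writes to buffers.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Exp:
    # Hypothetical stand-in for Experience: only the reward is needed here.
    reward: float


class RewardFilter:
    # Same logic as the RewardFilter above, without the Trinity-RFT base class and registry.
    def __init__(self, threshold: float = 0.0) -> None:
        self.threshold = threshold

    def process(self, exps: List[Exp]) -> Tuple[List[Exp], Dict]:
        filtered = [e for e in exps if e.reward >= self.threshold]
        return filtered, {"filtered_count": len(exps) - len(filtered)}


def run_operators(operators: List, exps: List[Exp]) -> Tuple[List[Exp], Dict]:
    """Hypothetical pipeline loop: each operator consumes the previous operator's output."""
    metrics: Dict = {}
    for op in operators:
        exps, op_metrics = op.process(exps)
        metrics.update(op_metrics)  # later operators overwrite metrics with the same key
    return exps, metrics


kept, metrics = run_operators([RewardFilter(threshold=0.1)], [Exp(r) for r in (0.0, 0.05, 0.5, 1.0)])
print([e.reward for e in kept], metrics)  # [0.5, 1.0] {'filtered_count': 2}
```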

Step 2: Use Your Operator

After completing the above steps, you can use the newly registered operator through a YAML configuration file.

# some other configs
data_processor:
  experience_pipeline:
    operators:
      - name: "reward_filter"
        args:
          threshold: 0.1
synchronizer:
  sync_method: nccl
  sync_style: dynamic_by_explorer
  sync_interval: 2
# some other configs

Tip

The RewardFilter reduces the number of experiences, which may leave the Trainer without enough experiences to start a training step. To avoid this issue, you can use the advanced Dynamic Synchronization feature provided by Trinity-RFT, as shown in the configuration above. This setting means that the Explorer syncs with the Trainer every 2 steps and keeps running regardless of how many steps the Trainer has completed. As a result, the Trainer can always get enough experiences to start a training step as long as the Explorer is running.


Adding New Config Entries for the Config Generator (Advanced)

Step 0: Understanding Streamlit

Before adding new parameters to the Config Generator page, it is essential to familiarize yourself with the relevant API and mechanisms of Streamlit. This project primarily utilizes various input components from Streamlit and employs st.session_state to store user-input parameters.

Step 1: Implement New Config Entries

To illustrate the process of creating a new parameter setting for the Config Generator page, we will use train_batch_size as an example.

  1. Determine the appropriate scope for the parameter. Currently, parameters are categorized into four files:

    • trinity/manager/config_registry/buffer_config_manager.py

    • trinity/manager/config_registry/explorer_config_manager.py

    • trinity/manager/config_registry/model_config_manager.py

    • trinity/manager/config_registry/trainer_config_manager.py

    In this case, train_batch_size should be placed in the buffer_config_manager.py file.

  2. Create a parameter setting function using Streamlit. The function name must follow the convention of starting with ‘set_’, and the remainder of the name becomes the config name.

  3. Decorate the parameter setting function with the CONFIG_GENERATORS.register_config decorator. This decorator requires the following information:

    • Default value of the parameter

    • Visibility condition (if applicable)

    • Additional config parameters (if needed)

Note

The CONFIG_GENERATORS.register_config decorator automatically passes key=config_name as an argument to the registered configuration function. Ensure that your function accepts this keyword argument.

For train_batch_size, we will use the following settings:

  • Default value: 96

  • Visibility condition: lambda: st.session_state["trainer_gpu_num"] > 0

  • Additional config: {"_train_batch_size_per_gpu": 16}

Here’s the complete code for the train_batch_size parameter:

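import streamlit as st

# CONFIG_GENERATORS and the help-text helper _str_for_train_batch_size are defined
# elsewhere in trinity/manager/config_registry.


@CONFIG_GENERATORS.register_config(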
    default_value=96,
    visible=lambda: st.session_state["trainer_gpu_num"] > 0,
    other_configs={"_train_batch_size_per_gpu": 16},
)
def set_train_batch_size(**kwargs):
    key = kwargs.get("key")
    trainer_gpu_num = st.session_state["trainer_gpu_num"]
    st.session_state[key] = (
        st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"]
    )

    def on_change():
        st.session_state["_train_batch_size_per_gpu"] = max(
            st.session_state[key] // st.session_state["trainer_gpu_num"], 1
        )

    st.number_input(
        "Train Batch Size",
        min_value=trainer_gpu_num,
        step=trainer_gpu_num,
        help=_str_for_train_batch_size(),
        on_change=on_change,
        **kwargs,
    )

If the parameter requires validation, create a check function. For train_batch_size, we need to ensure it is divisible by trainer_gpu_num. If not, a warning should be displayed, and the parameter should be added to unfinished_fields.

Decorate the check function with the CONFIG_GENERATORS.register_check decorator:

@CONFIG_GENERATORS.register_check()
def check_train_batch_size(unfinished_fields: set, key: str):
    if st.session_state[key] % st.session_state["trainer_gpu_num"] != 0:
        unfinished_fields.add(key)
        st.warning(_str_for_train_batch_size())

Note

The CONFIG_GENERATORS.register_check decorator automatically receives key=config_name and unfinished_fields=self.unfinished_fields as arguments. Ensure your function accepts these keyword arguments.
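The arithmetic behind set_train_batch_size and check_train_batch_size can be checked in isolation. The three helpers below are hypothetical, Streamlit-free mirrors of that logic; they assume trainer_gpu_num > 0, which the visibility condition is meant to ensure before the widget is shown.

```python
def per_gpu_from_total(train_batch_size: int, trainer_gpu_num: int) -> int:
    # Mirrors on_change(): round down, but keep at least one sample per GPU.
    return max(train_batch_size // trainer_gpu_num, 1)


def total_from_per_gpu(per_gpu: int, trainer_gpu_num: int) -> int:
    # Mirrors the value set_train_batch_size() writes into st.session_state[key] on each rerun.
    return per_gpu * trainer_gpu_num


def batch_size_is_valid(train_batch_size: int, trainer_gpu_num: int) -> bool:
    # Mirrors check_train_batch_size(): the total must split evenly across trainer GPUs.
    return train_batch_size % trainer_gpu_num == 0


# With 6 trainer GPUs, the default per-GPU size of 16 reproduces the default total of 96.
print(total_from_per_gpu(16, 6))  # 96

# Typing 100 with 8 GPUs fails the check; on_change() stores 12 per GPU, so per the code
# above the next rerun recomputes the total as 12 * 8 = 96.
print(batch_size_is_valid(100, 8), per_gpu_from_total(100, 8), total_from_per_gpu(12, 8))  # False 12 96
```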

Step 2: Integrating New Parameters into config_manager.py

To successfully integrate new parameters into the config_manager.py file, please adhere to the following procedure:

  1. Parameter Categorization: Determine the appropriate section for the new parameter based on its functionality. The config generator page is structured into two primary modes:

    • Beginner Mode: Comprises “Essential Configs” and “Important Configs” sections.

    • Expert Mode: Includes “Model”, “Buffer”, “Explorer and Synchronizer”, and “Trainer” sections.

  2. Parameter Addition: Incorporate the new parameter into the relevant section using the self.get_configs method within the ConfigManager class.

    Example:

    class ConfigManager:
        def _expert_buffer_part(self):
            self.get_configs("total_epochs", "train_batch_size")
    
  3. YAML File Integration: Locate the appropriate position for the new parameter within the YAML file structure. This should be done in the generate_config function and its associated sub-functions.

  4. Parameter Value Assignment: Utilize st.session_state to retrieve the parameter value from the config generator page and assign it to the corresponding field in the YAML.

    Example:

    class ConfigManager:
        def _gen_buffer_config(self):
            buffer_config = {
                "batch_size": st.session_state["train_batch_size"],
                # Additional configuration parameters
            }
    

Following these steps adds the new parameter to the Config Generator page and integrates it into the generated YAML configuration.


Contributing your Code

For modules that are prepared to be contributed to the Trinity-RFT project, please follow the steps below:

  1. Implement your code in the appropriate directory, such as trinity/common/workflows for workflows, trinity/algorithm for algorithms, trinity/buffer/operators for operators.

  2. Register your module in the corresponding __init__.py file of the directory.

  3. Add tests for your module in the tests directory, following the naming conventions and structure of existing tests.

  4. Before submitting the code, make sure it passes the code style check with pre-commit run --all-files.

  5. Submit a pull request to the Trinity-RFT repository, including a clear description of your changes.

Tip

For modules that are only used for local testing or are not intended for contribution, you can place them in the trinity/plugins directory. Trinity-RFT automatically loads all modules in this directory, so you can use them without adding them to the __init__.py file. You can specify a different directory with the --plugin-dir option when running Trinity-RFT, e.g., trinity run --config /path/to/your/config --plugin-dir /path/to/your/plugins.
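How Trinity-RFT discovers plugin modules internally is not described here. As a hypothetical sketch of the general mechanism, importing every .py file in a directory is enough for module-level registration code to run. The loader load_plugins, the demo_registry module and the plugin file contents are all made up for this demo; only standard importlib calls are used.

```python
import importlib.util
import sys
import tempfile
import types
from pathlib import Path
from typing import List


def load_plugins(plugin_dir: str) -> List[str]:
    """Hypothetical loader: import every .py file in plugin_dir so its registration code runs."""
    loaded = []
    for path in sorted(Path(plugin_dir).glob("*.py")):
        module_name = f"_plugin_{path.stem}"
        spec = importlib.util.spec_from_file_location(module_name, path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)  # module-level code (e.g. @register_module) runs here
        loaded.append(module_name)
    return loaded


# A stand-in registry module that the demo plugin imports, in place of Trinity-RFT's WORKFLOWS.
demo_registry = types.ModuleType("demo_registry")
demo_registry.WORKFLOWS = {}
sys.modules["demo_registry"] = demo_registry

with tempfile.TemporaryDirectory() as plugin_dir:
    Path(plugin_dir, "my_plugin.py").write_text(
        "import demo_registry\n"
        "\n"
        "class MyWorkflow:\n"
        "    pass\n"
        "\n"
        "demo_registry.WORKFLOWS['my_workflow'] = MyWorkflow\n"
    )
    loaded = load_plugins(plugin_dir)

print(loaded, list(demo_registry.WORKFLOWS))  # ['_plugin_my_plugin'] ['my_workflow']
```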