Workflow Development Guide#

In Trinity-RFT, workflows are the core components that define the interaction between Agents and Environments. A workflow must use a model to complete the specified task and obtain feedback (a reward) from the environment. Below are the steps to create a new workflow:


Step 0: Basic Concepts#

Before starting development, it’s important to understand several core concepts:

flowchart LR
    A([Task]) & B([Model]) --> C[Workflow]
    C --> D([Experience])
    
  • Task (trinity.common.workflows.Task): A data structure that contains all the information needed for a single run of the workflow. Tasks are usually provided by the training dataset: each sample in the dataset is converted into a Task instance. The content of the Task varies depending on the task type:

    • Math problems: A Task contains the problem description and the golden answer.

    • Programming scenarios: A Task includes the problem description, test cases, runtime environment, and other complex information.

  • Model (trinity.common.models.model.ModelWrapper): The model being trained. The workflow uses this model to generate responses based on the task. Trinity-RFT will provide the model instance to initialize the workflow.

  • Workflow (trinity.common.workflows.Workflow): Defines the interaction flow between Agents and Environments. It uses the Task to initialize itself and the Model to generate responses. Unlike general Agent applications, a Workflow also needs to calculate rewards based on the environment’s feedback. Trinity-RFT provides several built-in workflows.

  • Experience (trinity.common.experience.Experience): The output of running a Workflow. The number and structure of Experiences depend on the specific workflow. For example, for common PPO/GRPO algorithms, an Experience includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.
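
To make the Experience fields more concrete, the sketch below derives an action mask from a token list and a prompt length. It is only an illustration, assuming a single-turn rollout in which every token after the prompt was generated by the model; the build_action_mask helper and the token IDs are hypothetical and not part of Trinity-RFT.

```python
from typing import List


def build_action_mask(tokens: List[int], prompt_length: int) -> List[int]:
    """Return 1 for model-generated tokens and 0 for prompt tokens.

    Assumes a single-turn rollout: everything after the prompt is a response token.
    """
    if not 0 <= prompt_length <= len(tokens):
        raise ValueError("prompt_length must be between 0 and len(tokens)")
    return [0] * prompt_length + [1] * (len(tokens) - prompt_length)


# Hypothetical token IDs: 3 prompt tokens followed by 2 generated tokens.
tokens = [101, 2054, 2003, 1016, 102]
mask = build_action_mask(tokens, prompt_length=3)
print(mask)  # [0, 0, 0, 1, 1]
```

Multi-turn workflows would need a mask that alternates between environment and model segments, which this sketch does not cover.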


Step 1: Prepare Task Dataset#

The task dataset is loaded via the buffer.explorer_input.taskset configuration entry in your YAML config file. To handle differences in Task contents, Trinity-RFT provides a unified Task interface containing the following fields.

  • workflow (str): The registered name of your workflow class. You can specify it in buffer.explorer_input.taskset.default_workflow_type of your YAML config file.

  • reward_fn (Optional[str]): The registered name of your reward function. You can specify it in buffer.explorer_input.taskset.default_reward_fn_type. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.

  • raw_task (Dict): A record of raw data in Dict format. For highly customized workflows, you can directly use raw_task to initialize your Workflow instance without relying on the following fields.

  • format_args (trinity.common.config.FormatConfig): Parameters to facilitate the construction of Workflow instances. For example, the prompt_key and response_key can be used to get the prompt and response from raw_task. These settings come from the YAML configuration file and can be set in buffer.explorer_input.taskset.format.

  • rollout_args (trinity.common.config.GenerationConfig): Parameters that control the rollout process, such as temperature. This field also comes from the YAML configuration file and can be set in buffer.explorer_input.taskset.rollout_args.

  • workflow_args (Dict): A dictionary of parameters to facilitate the construction of Workflow instances. Because it is a plain dictionary, it is more flexible than format_args and rollout_args. This field also comes from the YAML configuration file and can be set in buffer.explorer_input.taskset.workflow_args. Normally, you do not need to set this field.

Tip

workflow, workflow_args and raw_task provide different levels of customization.

  • workflow provides the global settings for all tasks that use the same workflow. (Global Level)

  • workflow_args can be set for each task dataset, allowing different task datasets using the same workflow to behave differently. (Dataset Level)

  • raw_task provides the ability to customize the behavior of each task, which is most flexible. (Data Sample Level)

In the math problem scenario, the Task dataset can be a jsonl file, where each line contains JSON with question and answer fields representing the problem description and standard answer, respectively. For example:

{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...

Example configuration snippet:

# some config
buffer:
  explorer_input:
    taskset:
      default_workflow_type: "math_workflow"
      path: ${oc.env:TRINITY_TASKSET_PATH}
      format:
        prompt_key: "question"
        response_key: "answer"
      rollout_args:
        temperature: 1.0
      # some other configs

In this example, each task object’s raw_task is a Dict with two keys (question and answer). The MathWorkflow uses the prompt_key and response_key to extract the question and answer from the raw_task and uses the rollout_args to generate the response.
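
The key-based extraction described above can be sketched in plain Python. This is a minimal illustration that does not use Trinity-RFT itself: load_raw_tasks and extract_fields are hypothetical helpers, and the temporary file stands in for the dataset at buffer.explorer_input.taskset.path.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List, Tuple


def load_raw_tasks(path: Path) -> List[Dict]:
    """Read a JSONL file into a list of dicts, skipping blank lines."""
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def extract_fields(raw_task: Dict, prompt_key: str, response_key: str) -> Tuple[str, str]:
    """Pick the prompt and the reference answer out of one raw record."""
    return raw_task[prompt_key], raw_task[response_key]


# Write a tiny two-line taskset to a temporary file for the demonstration.
with tempfile.TemporaryDirectory() as tmp:
    taskset_path = Path(tmp) / "math.jsonl"
    taskset_path.write_text(
        '{"question": "1+1=", "answer": "2"}\n{"question": "2+2=", "answer": "4"}\n',
        encoding="utf-8",
    )
    raw_tasks = load_raw_tasks(taskset_path)

pairs = [extract_fields(t, prompt_key="question", response_key="answer") for t in raw_tasks]
print(pairs)  # [('1+1=', '2'), ('2+2=', '4')]
```

Changing prompt_key and response_key is all that is needed to point the same logic at a dataset with different column names.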


Step 2: Implement a New Workflow#

The Workflow base class interface is as follows:

class Workflow(ABC):

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        self.task = task
        self.model = model
        self.auxiliary_models = auxiliary_models

    @abstractmethod
    def run(self) -> List[Experience]:
        """Run the workflow and return a list of Experiences."""

Initialize Your Workflow#

During initialization, Workflow receives the following parameters:

  • task (trinity.common.workflows.Task): A single data item from the task dataset.

  • model (trinity.common.models.model.ModelWrapper): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including the reply text response_text, the full-sequence token IDs tokens, the prompt length in tokens prompt_length, and a list of output token log probabilities logprobs).

  • auxiliary_models (List[openai.OpenAI]): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.

Tip

You can switch to using the OpenAI API by setting explorer.rollout_model.enable_openai_api to true in your config file and calling model.get_openai_client() in your workflow to get an openai.OpenAI instance. The model field for OpenAI API calls can be obtained via openai_client.models.list().data[0].id or openai_client.model_path.

Here’s an example of initializing a simple workflow using only raw_task and rollout_args. In more complex cases, you can use the format_args for further customization.

class ExampleWorkflow(Workflow):

    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args
        # Optional: If you want to use OpenAI API in your workflow
        # self.openai_client = self.model.get_openai_client()

Implementing the run method#

The run method is the core of your workflow. It returns a list of Experience. Below is a simple implementation for a math workflow.

We first call the model to generate a response using the provided question and rollout arguments. Then we calculate the reward for each response using the calculate_reward function. Finally, we construct a list of Experience with the responses and rewards and return it.

class ExampleWorkflow(Workflow):

    # the __init__ function

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate a single response
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            temperature=self.rollout_args.temperature,
        )
        response = responses[0]  # there is only one response
        reward: float = self.calculate_reward(response.response_text, self.answer)
        return [
            Experience(
                tokens=response.tokens,
                prompt_length=response.prompt_length,
                reward=reward,
                logprobs=response.logprobs,
            )
        ]
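
The exact string comparison in calculate_reward is strict: a correct answer followed by a newline or written in a different case would receive a reward of 0. The sketch below shows one possible, slightly more lenient variant that normalizes both strings before comparing them. The normalize_answer helper is hypothetical, and whether such leniency is appropriate depends on your task.

```python
import re


def normalize_answer(text: str) -> str:
    """Lowercase, trim, and collapse whitespace so trivial differences are ignored."""
    return re.sub(r"\s+", " ", text.strip().lower())


def calculate_reward(response: str, truth: str) -> float:
    """Return 1.0 when the normalized response equals the normalized truth, else 0.0."""
    return 1.0 if normalize_answer(response) == normalize_answer(truth) else 0.0


print(calculate_reward(" 2\n", "2"))  # 1.0
print(calculate_reward("The answer is 2", "2"))  # 0.0
```

The second call still returns 0.0: normalization does not extract an answer from surrounding text, so free-form responses would need additional parsing.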

Registering Your Workflow#

Register your workflow using the WORKFLOWS.register_module decorator. Ensure the name does not conflict with existing workflows.

# import some packages
from trinity.common.workflows.workflow import WORKFLOWS

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    pass

If you plan to contribute the workflow to the Trinity-RFT project, place the above code in the trinity/common/workflows folder, e.g., trinity/common/workflows/example_workflow.py, and add the following lines to trinity/common/workflows/__init__.py:

# existing import lines
from trinity.common.workflows.example_workflow import ExampleWorkflow

__all__ = [
    # existing __all__ lines
    "ExampleWorkflow",
]
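
For readers unfamiliar with decorator-based registration, the sketch below shows the general pattern: a registry maps a string name to a class so that the class can later be looked up from a name given in a config file. It is an illustrative stand-in, not Trinity-RFT’s actual implementation; the Registry class, its get method, and DEMO_WORKFLOWS are hypothetical names.

```python
from typing import Callable, Dict, Type


class Registry:
    """A tiny name-to-class registry (illustrative, not Trinity-RFT's implementation)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, Type] = {}

    def register_module(self, module_name: str) -> Callable[[Type], Type]:
        """Return a class decorator that stores the class under module_name."""

        def decorator(cls: Type) -> Type:
            if module_name in self._modules:
                raise ValueError(f"{module_name!r} is already registered in {self.name}")
            self._modules[module_name] = cls
            return cls

        return decorator

    def get(self, module_name: str) -> Type:
        """Look up a registered class by name."""
        return self._modules[module_name]


DEMO_WORKFLOWS = Registry("workflows")  # hypothetical stand-in for WORKFLOWS


@DEMO_WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow:
    pass


print(DEMO_WORKFLOWS.get("example_workflow") is ExampleWorkflow)  # True
```

In this sketch, registering a second class under the same name raises an error, which illustrates why workflow names should not conflict.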

Performance Optimization#

Avoid Re-initialization#

For heavy workflows, re-initializing for every task can incur extra computational cost. In this case, you can set the can_reset property and implement the reset method to avoid re-initialization.

can_reset is a class property that indicates whether the workflow supports resetting.

The reset method accepts a Task parameter and resets the workflow’s internal state based on the new task.

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    can_reset: bool = True

    # some code
    # ...

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

Support Batch Inference#

Many popular RL algorithms (e.g., GRPO) require multiple runs of the same task. In such scenarios, you can use batch inference to obtain multiple responses for a single question in one call, which improves efficiency. To do so, set the can_repeat property and implement the set_repeat_times method.

can_repeat is a class property that indicates whether the workflow supports multiple executions within the run method.

The set_repeat_times method accepts two parameters: repeat_times specifies the number of times to execute within the run method, and run_id_base is an integer used to identify the first run ID in multiple runs (this parameter is used in multi-turn interaction scenarios; for tasks that can be completed with a single model call, this can be ignored).

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    can_repeat: bool = True
    # some code

    def set_repeat_times(self, repeat_times, run_id_base):
        self.repeat_times = repeat_times
        self.run_id_base = run_id_base

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.repeat_times,  # run multiple times in one call
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences
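
To see why algorithms such as GRPO need several responses per task, the sketch below normalizes the rewards within one group of responses to the same question. This is one common formulation of a group-relative advantage, shown here for intuition only: the computation belongs to the training algorithm rather than the workflow, and the function name, the use of the population standard deviation, and the eps value are assumptions.

```python
from statistics import mean, pstdev
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Normalize rewards within one group of responses to the same task."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population standard deviation of the group
    return [(r - mu) / (sigma + eps) for r in rewards]


# Rewards of four responses to the same question (two correct, two wrong).
advantages = group_relative_advantages([1.0, 0.0, 1.0, 0.0])
print([round(a, 3) for a in advantages])  # [1.0, -1.0, 1.0, -1.0]
```

With a single response per task, the group mean equals the reward and every advantage is zero, which is why the workflow is asked to produce repeat_times responses.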

Full Code Example#

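from typing import List

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows import Task, Workflow
from trinity.common.workflows.workflow import WORKFLOWS


@WORKFLOWS.register_module("example_workflow")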
class ExampleWorkflow(Workflow):
    can_reset: bool = True
    can_repeat: bool = True

    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.repeat_times,  # set via set_repeat_times()
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

    def set_repeat_times(self, repeat_times, run_id_base):
        self.repeat_times = repeat_times
        self.run_id_base = run_id_base

Step 3: Use Your Workflow#

After implementing and registering your workflow, update the configuration file: set default_workflow_type in the buffer.explorer_input.taskset section to the newly registered workflow name.

buffer:
  # Other fields
  explorer_input:
    taskset:
      path: /path/to/taskset
      default_workflow_type: example_workflow
      # Other fields

Now you can run your workflow in Trinity-RFT using the command:

trinity run --config <your_yaml_file>

Advanced Features#

Async Support#

The examples above target synchronous mode. If your workflow needs asynchronous methods (e.g., an asynchronous API), set is_async to True and implement the run_async method. In this case, you no longer need to implement the run method, and the auxiliary_models initialization parameter becomes List[openai.AsyncOpenAI]; all other methods and properties remain unchanged.

@WORKFLOWS.register_module("example_workflow_async")
class ExampleWorkflowAsync(Workflow):

    is_async: bool = True

    async def run_async(self) -> List[Experience]:
        # your async code here
        raise NotImplementedError

    # no need to implement run() method
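
As a rough illustration of what an asynchronous workflow gains, the sketch below issues several requests concurrently with asyncio.gather. FakeAsyncJudge and DemoAsyncWorkflow are hypothetical stand-ins that only simulate latency; they do not call Trinity-RFT or any real API, and the sketch returns plain scores instead of Experience objects.

```python
import asyncio
from typing import List


class FakeAsyncJudge:
    """Hypothetical stand-in for an async API client; it only sleeps and compares."""

    async def score(self, answer: str) -> float:
        await asyncio.sleep(0.01)  # simulates network latency
        return 1.0 if answer == "2" else 0.0


class DemoAsyncWorkflow:
    is_async: bool = True

    def __init__(self, answers: List[str]) -> None:
        self.answers = answers
        self.judge = FakeAsyncJudge()

    async def run_async(self) -> List[float]:
        # Issue all requests concurrently instead of awaiting them one by one.
        return await asyncio.gather(*(self.judge.score(a) for a in self.answers))


scores = asyncio.run(DemoAsyncWorkflow(["2", "3", "2"]).run_async())
print(scores)  # [1.0, 0.0, 1.0]
```

asyncio.gather preserves the order of its inputs, so each score lines up with the answer that produced it.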

Using OpenAI API#

Trinity-RFT provides an option to use the OpenAI API for model inference. You can enable this feature by setting explorer.rollout_model.enable_openai_api to true in your configuration file. This allows you to obtain an openai.OpenAI instance via the get_openai_client method of the model instance provided by Trinity-RFT.

Additionally, since the OpenAI API does not provide all the data required for training, you also need to set explorer.rollout_model.enable_history to true. This lets the framework automatically record data that can be used for training and convert it into a list of Experience. You can extract these experiences using the extract_experience_from_history method.

# example config snippet
explorer:
  rollout_model:
    enable_openai_api: true
    enable_history: true
    # Other fields

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.model = model
        self.client: openai.OpenAI = self.model.get_openai_client()
        # or async client
        # self.client: openai.AsyncOpenAI = self.model.get_openai_async_client()
        self.agent = MyAgent(openai_client=self.client)

    def calculate_reward(self, response: str) -> float:
        # your reward calculation logic
        raise NotImplementedError

    def run(self) -> List[Experience]:
        # run your agent
        response = self.agent.run()
        # calculate reward
        reward = self.calculate_reward(response)
        # extract experiences from history recorded in self.model
        experiences = self.model.extract_experience_from_history()
        for exp in experiences:
            exp.reward = reward
        return experiences

Tip

  1. Currently, only calls to openai.OpenAI.chat.completions.create and openai.AsyncOpenAI.chat.completions.create are recorded automatically and converted into Experience objects. Streaming output is not supported.

  2. When calling chat.completions.create, the model field can be obtained via openai_client.models.list().data[0].id or openai_client.model_path.

  3. For more complex workflow examples using the OpenAI API, refer to ReAct Agent Training.

LLM-as-a-judge Support#

LLM-as-a-judge is a common reward calculation method, especially suitable for open-ended tasks (such as programming, writing, etc.). In these scenarios, the Workflow needs to leverage an additional LLM to evaluate the answer quality and compute the reward signal.

To support this, Trinity-RFT provides an Auxiliary Models mechanism. Auxiliary models are a set of models not involved in training; the Workflow can use these models to assist with tasks, such as acting as a judge to calculate rewards.

You can specify one or more auxiliary models in the configuration file via the explorer.auxiliary_models field. For example:

explorer:
  auxiliary_models:
    - model_path: Qwen/Qwen2.5-32B-Instruct
      engine_num: 1
      tensor_parallel_size: 2
      enable_thinking: false
      max_prompt_tokens: 12288
      max_response_tokens: 12288
      max_model_len: 16384
    - model_path: Qwen/Qwen3-8B
      engine_num: 1
      tensor_parallel_size: 2
      enable_thinking: false
      max_prompt_tokens: 12288
      max_response_tokens: 12288
      max_model_len: 16384

Note that each auxiliary model will independently occupy tensor_parallel_size * engine_num GPUs. Please configure according to your hardware resources. After enabling auxiliary models, the number of GPUs available to the Trainer is the total GPU count minus those occupied by all auxiliary models and the inference model being trained (rollout_model).
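
The GPU accounting above can be checked with a few lines of arithmetic. The sketch below uses the two auxiliary models from the example config; the cluster size of 8 GPUs, the rollout model settings, and the assumption that the rollout model also occupies tensor_parallel_size * engine_num GPUs are illustrative values, and the helper functions are hypothetical.

```python
from typing import Dict, List


def gpus_for(model: Dict[str, int]) -> int:
    """GPUs used by one model: tensor_parallel_size * engine_num."""
    return model["tensor_parallel_size"] * model["engine_num"]


def trainer_gpus(
    total: int,
    rollout_model: Dict[str, int],
    auxiliary_models: List[Dict[str, int]],
) -> int:
    """GPUs left for the Trainer after the rollout and auxiliary models are placed."""
    used = gpus_for(rollout_model) + sum(gpus_for(m) for m in auxiliary_models)
    if used >= total:
        raise ValueError(f"{used} GPUs requested for inference, only {total} available")
    return total - used


# The two auxiliary models from the config above; the cluster size and the
# rollout model settings are assumed values for illustration.
auxiliary = [
    {"engine_num": 1, "tensor_parallel_size": 2},
    {"engine_num": 1, "tensor_parallel_size": 2},
]
rollout = {"engine_num": 2, "tensor_parallel_size": 1}
remaining = trainer_gpus(total=8, rollout_model=rollout, auxiliary_models=auxiliary)
print(remaining)  # 2
```

Here the two auxiliary models take 4 GPUs and the assumed rollout model takes 2, leaving 2 of the 8 GPUs for the Trainer.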

The auxiliary models specified in the configuration file will automatically activate the OpenAI API and pass the corresponding openai.OpenAI or openai.AsyncOpenAI instances (depending on the is_async setting) to the auxiliary_models parameter of the Workflow initialization method. For example:

class MyWorkflow(Workflow):
    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.judge_model = self.auxiliary_models[0]  # Use the first auxiliary model as the judge

    def run(self) -> List[Experience]:
        response = self.do_something()
        reward_response = self.judge_model.chat.completions.create(
            model=self.judge_model.model_path,
            messages=[
                {
                    "role": "system",
                    "content": "You are a judge. You need to give a score from 0 to 1 based on the quality of the answer.",
                },
                {
                    "role": "user",
                    "content": f"Question:\n{self.task.raw_task['question']}\nAnswer:\n{response.response_text}\nPlease give a score from 0 to 1.",
                },
            ],
            temperature=0.0,
            max_tokens=10,
        )
        # Parse the reward score
        reward = float(reward_response.choices[0].message.content.strip())
        return [
            Experience(
                tokens=response.tokens,
                prompt_length=response.prompt_length,
                reward=reward,
                logprobs=response.logprobs,
            )
        ]
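
The float(...) call above raises an exception whenever the judge replies with anything other than a bare number. A more defensive parser might look like the sketch below; parse_score, its regular expression, and the choice of 0.0 as the fallback reward are assumptions rather than part of Trinity-RFT.

```python
import re

_NUMBER = re.compile(r"[-+]?\d*\.?\d+")


def parse_score(text: str, default: float = 0.0) -> float:
    """Extract the first number from a judge reply and clamp it to [0, 1]."""
    match = _NUMBER.search(text)
    if match is None:
        return default
    return min(1.0, max(0.0, float(match.group())))


print(parse_score("0.8"))  # 0.8
print(parse_score("Score: 0.75 out of 1"))  # 0.75
print(parse_score("I cannot judge this."))  # 0.0
print(parse_score("7"))  # 1.0
```

Clamping keeps an out-of-range reply such as "7" from producing a reward outside the interval the judge prompt asks for.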

Debug Mode#

During Workflow development, repeatedly launching the full training process for testing is time-consuming and inefficient. To address this, Trinity-RFT provides a Debug Mode for developers. This mode leverages a pre-launched inference model to quickly run specified workflows and obtain results, avoiding repeated model loading and initialization delays, and significantly improving development efficiency. The process is illustrated below:

        flowchart LR
    A[Start Inference Model] --> B[Debug Workflow]
    B --> B
    

To start the inference model, use the following command:

trinity debug --config <config_file_path> --module inference_model

Here, <config_file_path> is the path to a YAML configuration file, which should follow the same format as the one used by the trinity run command. The explorer.rollout_model and explorer.auxiliary_models fields in the config will be loaded to initialize the inference model.

Once started, the model will keep running and wait for debug instructions; it will not exit automatically. You can then run the following command in another terminal to debug your workflow:

trinity debug --config <config_file_path> --module workflow --output_file <output_file_path> --plugin_dir <plugin_dir>

  • <config_file_path>: Path to the YAML configuration file, usually the same as used for starting the inference model.

  • <output_file_path>: Path to save the performance profiling results. Debug Mode uses viztracer to profile the workflow execution and saves the results as an HTML file for easy viewing in a browser.

  • <plugin_dir> (optional): Path to the plugin directory. If your workflow or reward function modules are not built into Trinity-RFT, you can specify this parameter to load custom modules.

During debugging, the buffer.explorer_input.taskset field in the config will be loaded to initialize the workflow’s required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow’s return value will be automatically formatted and printed in the terminal for easy inspection.

When debugging is complete, you can terminate the inference model by pressing Ctrl+C in its terminal.