Workflow Development Guide#

In Trinity-RFT, workflows are the core components that define the interaction between Agents and Environments. A workflow must use the model being trained to complete the specified task and obtain feedback (a reward) from the environment. The steps below show how to create a new workflow.


Step 0: Basic Concepts#

Before starting development, it’s important to understand several core concepts:

  • Task (trinity.common.workflows.Task): Represents a data structure that can be converted into a Workflow. The content of the Task varies depending on the task type:

    • Math problems: A Task contains the problem description and the golden answer.

    • Programming scenarios: A Task includes the problem description, test cases, runtime environment, and other complex information.

  • Workflow (trinity.common.workflows.Workflow): Describes how a Task is executed. It defines the interaction flow between Agents and Environments, including logic similar to Rollout and Reward calculations in other frameworks. After execution, it generates a list of Experience. Trinity-RFT includes several built-in workflows.

  • Experience (trinity.common.experience.Experience): The output of running a Workflow. The internal data format depends on the training algorithm used. For example, for common PPO/GRPO algorithms, Experience includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.
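To make the Experience fields more concrete, the following is a minimal sketch of how an action mask could be derived for a single-turn rollout. It assumes the sequence is laid out as prompt tokens followed by generated tokens; the helper name build_action_mask and the token IDs are hypothetical and not part of Trinity-RFT.

```python
from typing import List


def build_action_mask(tokens: List[int], prompt_length: int) -> List[int]:
    """Return 1 for tokens generated by the LLM and 0 for prompt tokens.

    Assumes a single-turn sequence laid out as [prompt tokens][response tokens].
    """
    if not 0 <= prompt_length <= len(tokens):
        raise ValueError("prompt_length must be between 0 and len(tokens)")
    return [0] * prompt_length + [1] * (len(tokens) - prompt_length)


tokens = [101, 7592, 2088, 102, 2017, 2024]  # made-up token IDs
mask = build_action_mask(tokens, prompt_length=4)
print(mask)  # [0, 0, 0, 0, 1, 1]
```

Multi-turn workflows would need a different mask layout, since generated and non-generated tokens alternate there.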


Step 1: Prepare Task Dataset#

The task dataset is loaded via the buffer.explorer_input.taskset configuration entry in your YAML config file. To handle differences in Task contents, Trinity-RFT provides a unified Task interface with the following fields:

  • workflow (str): The registered name of your workflow class. You can specify it in buffer.explorer_input.taskset.default_workflow_type of your YAML config file.

  • reward_fn (Optional[str]): The registered name of your reward function. You can specify it in buffer.explorer_input.taskset.default_reward_fn_type. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.

  • raw_task (Dict): A record of raw data in Dict format. For highly customized workflows, you can use raw_task directly to initialize your Workflow instance without relying on the fields below.

  • format_args (trinity.common.config.FormatConfig): Parameters to facilitate the construction of Workflow instances. For example, the prompt_key and response_key can be used to get the prompt and response from raw_task. These settings come from the YAML configuration file and can be set in buffer.explorer_input.taskset.format.

  • rollout_args (trinity.common.config.GenerationConfig): Parameters that control the rollout process, such as temperature. This field also comes from the YAML configuration file and can be set in buffer.explorer_input.taskset.rollout_args.

  • workflow_args (Dict): A dictionary of parameters to facilitate the construction of Workflow instances. Because it is a plain dictionary, it is more flexible than format_args and rollout_args. This field also comes from the YAML configuration file and can be set in buffer.explorer_input.taskset.workflow_args. Normally, you do not need to set this field.

Tip

workflow, workflow_args and raw_task provide different levels of customization.

  • workflow provides the global settings for all tasks that use the same workflow. (Global Level)

  • workflow_args can be set for each task dataset, allowing different task datasets that use the same workflow to behave differently. (Dataset Level)

  • raw_task provides the ability to customize the behavior of each task, which is the most flexible option. (Data Sample Level)
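One way to picture the three levels is as a layered lookup in which the more specific level overrides the more general one. The sketch below is only an illustration in plain Python: the WORKFLOW_DEFAULTS dictionary, the resolve_settings helper, the "overrides" key, and the precedence order itself are assumptions made for this example, not behavior defined by Trinity-RFT.

```python
from typing import Any, Dict

# Hypothetical defaults hard-coded in the workflow class (global level).
WORKFLOW_DEFAULTS: Dict[str, Any] = {
    "system_prompt": "You are a helpful assistant.",
    "max_steps": 1,
}


def resolve_settings(workflow_args: Dict[str, Any], raw_task: Dict[str, Any]) -> Dict[str, Any]:
    """Merge settings so that more specific levels override more general ones."""
    settings = dict(WORKFLOW_DEFAULTS)               # global level
    settings.update(workflow_args)                   # dataset level
    settings.update(raw_task.get("overrides", {}))   # data sample level (hypothetical key)
    return settings


dataset_args = {"max_steps": 3}
sample = {"question": "1+1=", "answer": "2", "overrides": {"max_steps": 5}}
print(resolve_settings(dataset_args, sample)["max_steps"])  # 5
```

In a real workflow, this merging would typically happen in __init__, where task.workflow_args and task.raw_task are both available.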

In the math problem scenario, the Task dataset can be a jsonl file, where each line contains JSON with question and answer fields representing the problem description and standard answer, respectively. For example:

{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...

Example configuration snippet:

# some config
buffer:
  explorer_input:
    taskset:
      default_workflow_type: "math_workflow"
      path: ${oc.env:TRINITY_TASKSET_PATH}
      format:
        prompt_key: "question"
        response_key: "answer"
      rollout_args:
        temperature: 1.0
      # some other configs

In this example, each task object’s raw_task is a Dict with two keys (question and answer). The MathWorkflow uses the prompt_key and response_key to extract the question and answer from raw_task, and uses the rollout_args to generate the response.
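The key-based extraction described above can be sketched in plain Python as follows. This is not the MathWorkflow implementation; load_raw_tasks and extract_pair are hypothetical helpers that only illustrate how prompt_key and response_key select fields from each JSONL record.

```python
import json
from typing import Dict, List, Tuple

# Two records in the same shape as the JSONL example above.
JSONL_TEXT = '{"question": "1+1=", "answer": "2"}\n{"question": "2+2=", "answer": "4"}\n'


def load_raw_tasks(text: str) -> List[Dict[str, str]]:
    """Parse one JSON object per non-empty line."""
    return [json.loads(line) for line in text.splitlines() if line.strip()]


def extract_pair(raw_task: Dict[str, str], prompt_key: str, response_key: str) -> Tuple[str, str]:
    """Pick the prompt and the reference answer out of a raw record."""
    return raw_task[prompt_key], raw_task[response_key]


raw_tasks = load_raw_tasks(JSONL_TEXT)
pairs = [extract_pair(t, prompt_key="question", response_key="answer") for t in raw_tasks]
print(pairs)  # [('1+1=', '2'), ('2+2=', '4')]
```

If your dataset uses different column names, only the two key values need to change; the records themselves can stay as they are.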


Step 2: Implement a New Workflow#

The Workflow base class interface is as follows:

class Workflow(ABC):

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        self.task = task
        self.model = model
        self.auxiliary_models = auxiliary_models

    @abstractmethod
    def run(self) -> List[Experience]:
        """Run the workflow and return a list of Experiences."""

Initialize Your Workflow#

During initialization, Workflow receives the following parameters:

  • task (trinity.common.workflows.Task): A single data item from the task dataset.

  • model (trinity.common.models.model.ModelWrapper): The model being trained. It provides an OpenAI-like interface that accepts a list of conversation messages and returns the content generated by the LLM, including the reply text (response_text), the token IDs of the full sequence (tokens), the token length of the prompt (prompt_length), and the log probabilities of the output tokens (logprobs).

  • auxiliary_models (List[openai.OpenAI]): A list of auxiliary models that are not involved in training. All of them are provided via OpenAI-compatible APIs.

Tip

You can switch to the OpenAI API by setting explorer.rollout_model.enable_openai_api to true in your config file and calling model.get_openai_client() in your workflow to get an openai.OpenAI instance. The value for the model field of OpenAI API calls can be obtained via openai_client.models.list().data[0].id or openai_client.model_path.
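As a rough sketch of the tip above, the helper below sends one user message through an OpenAI-compatible client and returns the reply text. The function name chat_with_openai_client is hypothetical; the client is assumed to be the object returned by model.get_openai_client(), and the call uses the standard chat.completions.create method of the openai package.

```python
from typing import Any, Dict, List


def chat_with_openai_client(openai_client: Any, question: str, temperature: float = 1.0) -> str:
    """Send one user message through an OpenAI-compatible client and return the reply text."""
    # Look up the served model name, as described in the tip above.
    model_id = openai_client.models.list().data[0].id
    messages: List[Dict[str, str]] = [{"role": "user", "content": question}]
    completion = openai_client.chat.completions.create(
        model=model_id,
        messages=messages,
        temperature=temperature,
    )
    return completion.choices[0].message.content
```

Inside a workflow, this could be called as chat_with_openai_client(self.model.get_openai_client(), self.question), provided enable_openai_api is set to true.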

Here’s an example of initializing a simple workflow using only raw_task and rollout_args. In more complex cases, you can use the format_args for further customization.

class ExampleWorkflow(Workflow):

    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args
        # Optional: If you want to use OpenAI API in your workflow
        # self.openai_client = self.model.get_openai_client()

Implementing the run method#

The run method is the core of your workflow. It returns a list of Experience. Below is a simple implementation for a math workflow.

We first call the model to generate multiple responses using the provided question and rollout arguments. Then we calculate the reward for each response using the calculate_reward function. Finally, we construct a list of Experience from the responses and rewards and return it.

class ExampleWorkflow(Workflow):

    # the __init__ function

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences
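The exact string comparison in calculate_reward only gives a reward when the model outputs nothing but the answer. A slightly more tolerant variant is sketched below, under the assumption that the final answer is the last number appearing in the response. The helper extract_last_number is hypothetical, and this is only one possible reward design, not the one used by Trinity-RFT's built-in workflows.

```python
import re
from typing import Optional

# Matches integers and simple decimals, with an optional leading minus sign.
_NUMBER = re.compile(r"-?\d+(?:\.\d+)?")


def extract_last_number(text: str) -> Optional[str]:
    """Return the last number found in the text, or None if there is none."""
    matches = _NUMBER.findall(text)
    return matches[-1] if matches else None


def calculate_reward(response: str, truth: str) -> float:
    """Return 1.0 if the last number in the response equals the reference answer."""
    try:
        target = float(truth)
    except ValueError:
        # Non-numeric reference answer: fall back to a whitespace-insensitive match.
        return 1.0 if response.strip() == truth.strip() else 0.0
    predicted = extract_last_number(response)
    return 1.0 if predicted is not None and float(predicted) == target else 0.0


print(calculate_reward("1+1=2, so the answer is 2.", "2"))  # 1.0
print(calculate_reward("I think it is 3", "2"))             # 0.0
```

The same function body can replace the calculate_reward method in the workflow, with self added as the first parameter.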

Registering Your Workflow#

Register your workflow using the WORKFLOWS.register_module decorator. Ensure the name does not conflict with existing workflows.

# import some packages
from trinity.common.workflows.workflow import WORKFLOWS

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    pass
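For readers unfamiliar with registry decorators, the toy example below shows the general pattern: the decorator stores the class under a name so it can later be looked up from a string in the config. The Registry class and TOY_WORKFLOWS object are hypothetical and written only for illustration; they are not Trinity-RFT's WORKFLOWS implementation.

```python
from typing import Callable, Dict, Type


class Registry:
    """A toy name-to-class registry (hypothetical, for illustration only)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, Type] = {}

    def register_module(self, module_name: str) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            # Reject duplicate names so one workflow cannot silently replace another.
            if module_name in self._modules:
                raise ValueError(f"{module_name!r} is already registered in {self.name}")
            self._modules[module_name] = cls
            return cls

        return decorator

    def get(self, module_name: str) -> Type:
        return self._modules[module_name]


TOY_WORKFLOWS = Registry("workflows")


@TOY_WORKFLOWS.register_module("example_workflow")
class ToyExampleWorkflow:
    pass


print(TOY_WORKFLOWS.get("example_workflow") is ToyExampleWorkflow)  # True
```

This also shows why the registered name must be unique: the name is the only key used to find the class.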

For workflows intended to be contributed to the Trinity-RFT project, place the above code in the trinity/common/workflows folder, e.g., trinity/common/workflows/example_workflow.py, and add the following lines to trinity/common/workflows/__init__.py:

# existing import lines
from trinity.common.workflows.example_workflow import ExampleWorkflow

__all__ = [
    # existing __all__ lines
    "ExampleWorkflow",
]

Avoid Re-initialization#

For heavy workflows, re-initializing the workflow for every task can incur extra computational cost. In this case, you can implement the resettable and reset methods to avoid re-initialization.

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    # some code
    # ...

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
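To illustrate what these two methods enable, here is a simplified sketch of a runner loop that reuses one workflow instance across tasks when resettable returns True. ToyTask, ToyWorkflow, and run_all are hypothetical stand-ins written for this example; how Trinity-RFT's runner actually schedules and reuses workflows is not shown here.

```python
from typing import Dict, List, Optional


class ToyTask:
    """Stand-in for a task that only carries raw_task."""

    def __init__(self, raw_task: Dict[str, str]) -> None:
        self.raw_task = raw_task


class ToyWorkflow:
    """Stand-in workflow that counts how often it is constructed."""

    init_count = 0

    def __init__(self, task: ToyTask) -> None:
        ToyWorkflow.init_count += 1  # imagine expensive setup happening here
        self.reset(task)

    def resettable(self) -> bool:
        return True

    def reset(self, task: ToyTask) -> None:
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

    def run(self) -> List[str]:
        return [f"{self.question}{self.answer}"]


def run_all(tasks: List[ToyTask]) -> List[str]:
    """Construct the workflow once and reset it for every later task."""
    workflow: Optional[ToyWorkflow] = None
    outputs: List[str] = []
    for task in tasks:
        if workflow is not None and workflow.resettable():
            workflow.reset(task)
        else:
            workflow = ToyWorkflow(task)
        outputs.extend(workflow.run())
    return outputs


tasks = [ToyTask({"question": "1+1=", "answer": "2"}), ToyTask({"question": "2+2=", "answer": "4"})]
print(run_all(tasks))          # ['1+1=2', '2+2=4']
print(ToyWorkflow.init_count)  # 1
```

Any state that depends on the task must be refreshed in reset; anything left untouched there carries over from the previous task.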

Full Code Example#

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
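Before running the workflow against a real model, it can help to exercise the rollout-and-reward logic with a stub. The sketch below assumes only what this guide states about the model interface: chat accepts a message list plus n and temperature, and returns objects with response_text, tokens, prompt_length, and logprobs. StubResponse, StubModel, and score_responses are hypothetical names, and the token values are made up.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class StubResponse:
    """Carries the four response attributes named in this guide."""

    response_text: str
    tokens: List[int]
    prompt_length: int
    logprobs: List[float]


class StubModel:
    """Returns canned replies instead of calling an LLM."""

    def __init__(self, replies: List[str]) -> None:
        self.replies = replies

    def chat(self, messages: List[Dict[str, str]], n: int = 1, temperature: float = 1.0) -> List[StubResponse]:
        return [
            StubResponse(response_text=text, tokens=[1, 2, 3], prompt_length=2, logprobs=[-0.1])
            for text in self.replies[:n]
        ]


def score_responses(model: StubModel, question: str, answer: str, n: int) -> List[float]:
    """Mirror the generate-then-score loop of ExampleWorkflow.run."""
    responses = model.chat(
        [{"role": "user", "content": f"Question:\n{question}"}],
        n=n,
        temperature=1.0,
    )
    return [1.0 if r.response_text == answer else 0.0 for r in responses]


rewards = score_responses(StubModel(["2", "3", "2"]), question="1+1=", answer="2", n=3)
print(rewards)  # [1.0, 0.0, 1.0]
```

The same stub can be passed as the model argument of a workflow under test, as long as the workflow only relies on the chat method.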

Step 3: Use Your Workflow#

After implementing and registering your workflow, update the configuration file so that default_workflow_type in the buffer.explorer_input.taskset section is set to the newly registered workflow name.

buffer:
  # Other fields
  explorer_input:
    taskset:
      path: /path/to/taskset
      default_workflow_type: example_workflow
      # Other fields

Now you can run your workflow in Trinity-RFT using the command:

trinity run --config <your_yaml_file>