Workflow 开发指南#

在 Trinity-RFT 中,工作流(Workflow)是定义 Agent 与 Environment 之间交互的核心组件。 一个合格的工作流需要使用被训练模型完成指定任务,并从环境中获取反馈信息(奖励)。本节将会介绍如何开发一个新的工作流。


步骤 0:基本概念#

在开发之前,理解以下几个核心概念非常重要:

        flowchart LR
    A([Task]) & B([Model]) --> C[Workflow]
    C --> D([Experience])
    
  • 任务(Task) (trinity.common.workflows.Task):结构化的数据实例,包含了工作流一次运行所需的各种信息。一般情况下由训练数据集提供,数据集中的每个样本都会被转化为一个 Task 实例。Task 的内容根据任务类型而异:

    • 数学问题:包含问题和答案。

    • 编程场景:包含题目的描述、测试用例、运行环境等复杂信息。

  • 模型(Model) (trinity.common.models.model.ModelWrapper):被训练的模型,工作流内需要使用该模型来执行推理。该实例由 Trinity-RFT 自动提供,支持同步以及异步的 generate 以及 chat 等方法,同时也提供了 OpenAI API 接口,能够兼容大部分 Agent 框架。

  • 工作流(Workflow) (trinity.common.workflows.Workflow):定义了 Agent 与 Environment 的交互流程。Workflow 通过 Task 中提供的信息初始化自身,并借助 Model 来执行其中定义好的交互流程。与常规 Agent 应用不同的是,工作流内部还需要计算奖励信号(reward)以指导训练过程。Trinity-RFT 包含多个内置工作流:

    • MathWorkflow (trinity.common.workflows.MathWorkflow):用于数学场景,将问题提交给 LLM,解析 LLM 响应,并计算分数(奖励)。

    • WebShopWorkflow (trinity.common.workflows.WebShopWorkflow):用于 webshop 场景,包含与环境的多轮交互。

    • AgentScopeReActWorkflow (trinity.common.workflows.AgentScopeReActWorkflow):直接使用现有的 ReActAgent(基于 AgentScope)来解决问题。

  • 经验(Experience) (trinity.common.experience.Experience):Workflow 的运行产出。产出的数量以及内部数据格式取决于所使用的训练算法。例如,对于常见的 PPO/GRPO 算法,Experience 包含 token ID 列表、动作掩码(标识哪些 token 是由 LLM 生成的)、每个 token 的对数概率(logprobs)、奖励信号(reward)等。


步骤 1:准备任务数据集#

任务数据集通过 YAML 配置文件中的 buffer.explorer_input.taskset 配置项加载。 为处理 Task 内容的差异,Trinity-RFT 提供了一个统一的 Task 接口,包含以下字段:

  • workflow (str):你的工作流类的注册名称。你可以在 YAML 配置文件的 buffer.explorer_input.taskset.default_workflow_type 中指定。

  • reward_fn (Optional[str]):你的奖励函数的注册名称。你可以在 buffer.explorer_input.taskset.default_reward_fn_type 中指定。注意某些工作流已内置奖励计算;此时可省略该字段。

  • raw_task (Dict):原始数据的记录,以 Dict 格式存储。对于高度定制化的工作流,你可以直接使用 raw_task 初始化 Workflow 实例,而不依赖以下字段。

  • format_args (trinity.common.config.FormatConfig):便于构造 Workflow 实例的参数。例如,prompt_keyresponse_key 可用于从 raw_task 中提取 prompt 和 response。这些设置来自 YAML 配置文件,可在 buffer.explorer_input.task_set.format 中设置。

  • rollout_args (trinity.common.config.GenerationConfig):控制 rollout 过程的参数,如 temperature。该字段也来自 YAML 配置文件,可在 buffer.explorer_input.task_set.rollout_args 中设置。

  • workflow_args (Dict):用于构造 Workflow 实例的参数字典。相比 format_argsrollout_args 更灵活。该字段也来自 YAML 配置文件,可在 buffer.explorer_input.task_set.workflow_args 中设置。通常无需设置此字段。

小技巧

workflowworkflow_argsraw_task 提供了不同级别的自定义能力。

  • workflow 为使用相同工作流的所有任务提供全局设置。(全局级别)

  • workflow_args 可为每个任务数据集设置,允许使用相同工作流的不同任务数据集表现出不同行为。(数据集级别)

  • raw_task 提供对每个任务行为的自定义能力,最为灵活。(数据样本级别)

在数学问题场景中,Task 数据集可以是一个 jsonl 文件,每行包含带有 questionanswer 字段的 JSON,分别表示问题描述和标准答案。例如:

{"question": "1+1=", "answer": "2"}
{"question": "2+2=", "answer": "4"}
...

配置示例片段:

# some config
buffer:
  explorer_input:
    taskset:
      default_workflow: "math_workflow"
      path: ${oc.env:TRINITY_TASKSET_PATH}
      format:
        prompt_key: "question"
        response_key: "answer"
      rollout_args:
        temperature: 1.0
      # some other configs

在此示例中,每个任务对象的 raw_task 是一个包含两个键(questionanswer)的 DictMathWorkflow 使用 prompt_keyresponse_keyraw_task 中提取问题和答案,并使用 rollout_args 生成响应。


步骤 2:实现工作流#

Workflow 基类接口如下:

class Workflow(ABC):

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,  # 主要用于 LLM-as-a-judge 场景,这里可以忽略
    ):
        self.task = task
        self.model = model
        self.auxiliary_models = auxiliary_models

    @abstractmethod
    def run(self) -> List[Experience]:
        """Run the workflow and return a list of Experiences."""

初始化你的工作流#

Workflow 接受以下初始化参数:

  • task(trinity.common.workflows.Task):数据集中的单个任务。

  • model(trinity.common.models.model.ModelWrapper):正在训练的模型,提供类似于 OpenAI 的接口,能够接收对话消息列表并返回 LLM 生成的内容(包括回复文本 response_text、完整序列 token id tokens、prompt 部分 token 长度 prompt_length,以及输出 token 对数概率列表 logprobs)。

  • auxiliary_models(List[openai.OpenAI]):未参与训练的辅助模型列表。所有模型均通过兼容 OpenAI 的 API 提供,主要用于 LLM-as-a-judge 场景。

以下是一个仅使用 raw_taskrollout_args 初始化简单工作流的示例。在更复杂的情况下,你可以使用 format_args 进行进一步自定义。

class ExampleWorkflow(Workflow):

    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args
        # Optional: If you want to use OpenAI API in your workflow
        # self.openai_client = self.model.get_openai_client()

实现 run 方法#

run 方法是工作流的核心方法。该方法没有输入参数,返回一个 Experience 列表。 以下是一个数学工作流的简单实现。

我们首先调用模型,使用给定的问题和 rollout 参数生成答案。 然后使用 calculate_reward 函数计算答案的奖励。 最后,我们将生成的答案和奖励封装为Experience 实例并返回。

class ExampleWorkflow(Workflow):

    # the __init__ function

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            temperature=self.rollout_args.temperature,
        )
        response = responses[0]  # there is only one response
        reward: float = self.calculate_reward(response.response_text, self.answer)
        return [
            Experience(
                tokens=response.tokens,
                prompt_length=response.prompt_length,
                reward=reward,
                logprobs=response.logprobs,
            )
        ]

注册你的工作流#

为了让 Trinity-RFT 能够通过配置文件中的名称自动找到你的工作流,你需要使用 WORKFLOWS.register_module 装饰器注册。

# import some packages
from trinity.common.workflows.workflow import WORKFLOWS

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    pass

对于准备贡献给 Trinity-RFT 项目的模块,你需要将上述代码放入 trinity/common/workflows 文件夹中,例如 trinity/common/workflows/example_workflow.py。并在 trinity/common/workflows/__init__.py 中添加以下行:

# existing import lines
from trinity.common.workflows.example_workflow import ExampleWorkflow

__all__ = [
    # existing __all__ lines
    "ExampleWorkflow",
]

性能调优#

以下是一些可选的性能调优方法,能够提升工作流的运行效率。当然,这些方法并非所有工作流都需要实现,具体取决于你的工作流设计。

避免重复初始化#

对于较为复杂的工作流,每次重新初始化会带来额外计算开销。 此时,你可以实现 resettablereset 方法以避免重复初始化。

resettable 方法返回一个布尔值,指示工作流是否支持轻量化重置。

reset 方法接受一个新的 Task 实例,并使用该实例更新工作流的状态。

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    # some code
    # ...

    @property
    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

批量运行推理任务#

当前流行的很多 RL 算法需要多次运行同一个任务(例如 GRPO)。该场景下一些简单任务可以直接通过模型批量推理来获得一个问题的多个回复以提升效率。 针对该情况,你可以实现 repeatable 属性以及 set_repeat_times 方法。

repeatable 属性返回一个布尔值,指示工作流是否支持在 run 方法内多次执行。

set_repeat_times 方法接受两个参数:repeat_times 指定了在 run 方法内需要执行的次数,run_id_base 是一个整数,用于标识多次运行中第一次的运行 ID,之后各次的 ID 基于此递增(该参数用于多轮交互场景,单次模型调用即可完成的任务可以忽略该项)。

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
    # some code

    @property
    def repeatable(self) -> bool:
        return True

    def set_repeat_times(self, repeat_times, run_id_base):
        self.repeat_times = repeat_times
        self.run_id_base = run_id_base

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.repeat_times,  # run multiple times in one call
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calculate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

完整代码示例#

@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")
        self.rollout_args = task.rollout_args

    def calculate_reward(self, response: str, truth: str) -> float:
        if response == truth:
            return 1.0
        else:
            return 0.0

    def run(self) -> List[Experience]:
        # call the model to generate multiple responses
        responses = self.model.chat(
            [
                {
                    "role": "user",
                    "content": f"Question:\n{self.question}",
                }
            ],
            n=self.rollout_args.n,
            temperature=self.rollout_args.temperature,
        )
        experiences = []
        for response in responses:
            # calulcate reward
            reward: float = self.calculate_reward(response.response_text, self.answer)
            # construct Experience
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences

    @property
    def resettable(self):
        return True

    def reset(self, task: Task):
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

    @property
    def repeatable(self) -> bool:
        return True

    def set_repeat_times(self, repeat_times, run_id_base):
        self.repeat_times = repeat_times
        self.run_id_base = run_id_base

步骤 3:使用你的工作流#

实现并注册工作流后,就可以通过将配置文件中 buffer.explorer_input.tasksetdefault_workflow_type 域设置为你的工作流名称来使用它。例如:

buffer:
  # Other fields
  explorer_input:
    taskset:
      path: /path/to/taskset
      default_workflow_type: example_workflow
      # Other fields

现在你可以使用以下命令在 Trinity-RFT 中运行你的工作流:

trinity run --config <your_yaml_file>

其他进阶特性#

async 支持#

本节样例主要针对同步模式,如果你的工作流需要使用异步方法(例如异步 API),你可以实现 asynchronous 属性并将其设置为 True,然后实现 run_async 方法,在这种情况下不再需要实现 run 方法,其余方法和属性不受影响。

@WORKFLOWS.register_module("example_workflow_async")
class ExampleWorkflowAsync(Workflow):

    @property
    def asynchronous(self):
        return True

    async def run_async(self) -> List[Experience]:
        # your async code here

    # no need to implement run() method

使用 OpenAI API#

Trinity-RFT 的 Model 提供了 OpenAI API 接口,能够降低模型推理部分的学习成本并简化工作流的实现。

为了激活 OpenAI API 服务,你需要将配置文件中 explorer.rollout_model.enable_openai_api 设置为 true 。这样就可以通过 Model 实例的 get_openai_client 方法获取 openai.OpenAI 实例。

另外,由于 OpenAI API 无法提供训练所需的各项数据,你还需要将 explorer.rollout_model.enable_history 设置为 true,让框架自动记录可用于训练的数据并转化为 Experience 列表。你可以通过 extract_experience_from_history 方法来提取这些可用于训练的数据。

# example config snippet
explorer:
  rollout_model:
    enable_openai_api: true
    enable_history: true
    # Other fields
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.model = model
        self.client: openai.OpenAI = self.model.get_openai_client()
        # or async client
        # self.client: openai.AsyncOpenAI = self.model.get_openai_async_client()
        self.agent = MyAgent(openai_client=self.client)

    def calculate_reward(self, response: str) -> float:
        # your reward calculation logic

    def run(self) -> List[Experience]:
        # run your agent
        response = self.agent.run()
        # calculate reward
        reward = self.calculate_reward(response)
        # extract experiences from history recorded in self.model
        experiences = self.model.extract_experience_from_history()
        for exp in experiences:
            exp.reward = reward
        return experiences

小技巧

  1. 当前的 OpenAI API 仅会自动记录 openai.OpenAI.chat.completions.create 以及 openai.AsyncOpenAI.chat.completions.create 方法的调用历史并转化为 Experience 结构,且不支持流式输出。

  2. 调用 chat.completions.create 时,其中的 model 字段可通过 openai_client.models.list().data[0].idopenai_client.model_path 获取。

  3. 更复杂的使用 OpenAI API 的工作流实例可参考 ReAct Agent 训练

LLM-as-a-judge 支持#

LLM-as-a-judge 是一种常见的奖励计算方法,尤其适用于开放式任务(如编程、写作等)。在这类场景下,Workflow 需要借助额外的 LLM 来评估答案质量并计算奖励信号(reward)。

为此,Trinity-RFT 提供了 Auxiliary Models(辅助模型)机制。辅助模型是一组未参与训练的模型,Workflow 可利用这些模型辅助完成任务,例如作为评判者(judge)计算奖励。

你可以在配置文件中通过 explorer.auxiliary_models 字段指定一个或多个辅助模型。例如:

explorer:
  auxiliary_models:
    - model_path: Qwen/Qwen2.5-32B-Instruct
      engine_num: 1
      tensor_parallel_size: 2
      enable_thinking: false
      max_prompt_tokens: 12288
      max_response_tokens: 12288
      max_model_len: 16384
    - model_path: Qwen/Qwen3-8B
      engine_num: 1
      tensor_parallel_size: 2
      enable_thinking: false
      max_prompt_tokens: 12288
      max_response_tokens: 12288
      max_model_len: 16384

请注意,每个辅助模型会独立占用 tensor_parallel_size * engine_num 个 GPU,请根据硬件资源合理配置。在启用辅助模型后,Trainer 可用的 GPU 数量为总 GPU 数量减去所有辅助模型及被训练的推理模型(rollout_model)所占用的 GPU 数量。

配置文件中指定的辅助模型会自动激活 OpenAI API,并将对应的 openai.OpenAI 实例传递给 Workflow 初始化方法的 auxiliary_models 参数。例如:

class MyWorkflow(Workflow):
    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.judge_model = self.auxiliary_models[0]  # 使用第一个辅助模型作为评判者

    def run(self) -> List[Experience]:
        response = self.do_something()
        reward_response = self.judge_model.chat.completions.create(
            model=self.judge_model.model_path,
            messages=[
                {
                    "role": "system",
                    "content": "You are a judge. You need to give a score from 0 to 1 based on the quality of the answer.",
                },
                {
                    "role": "user",
                    "content": f"Question:\n{self.task.raw_task['question']}\nAnswer:\n{response.response_text}\nPlease give a score from 0 to 1.",
                },
            ],
            temperature=0.0,
            max_tokens=10,
        )
        # 解析奖励分数
        reward = float(reward_response.choices[0].message.content.strip())
        return [
            Experience(
                tokens=response.tokens,
                prompt_length=response.prompt_length,
                reward=reward,
                logprobs=response.logprobs,
            )
        ]

调试模式(Debug Mode)#

在 Workflow 开发过程中,频繁启动完整训练流程进行测试既耗时又低效。为此,Trinity-RFT 为开发者提供了调试模式。该模式通过预先启动推理模型,能够快速运行指定的工作流并获取结果,避免因模型加载和初始化带来的重复等待,大幅提升开发效率。流程如下:

        flowchart LR
    A[启动推理模型] --> B[调试 Workflow]
    B --> B
    

启动推理模型的命令如下:

trinity debug --config <config_file_path> --module inference_model

其中,config_file_path 为 YAML 格式的配置文件路径,格式与 trinity run 命令所用配置文件一致。配置文件中的 explorer.rollout_modelexplorer.auxiliary_models 字段会被加载,用于初始化推理模型。

模型启动后会持续运行并等待调试指令,不会自动退出。此时,你可在另一个终端执行如下命令进行 Workflow 调试:

trinity debug --config <config_file_path> --module workflow --output_file <output_file_path> --plugin_dir <plugin_dir>
  • config_file_path:YAML 配置文件路径,通常与启动推理模型时使用的配置文件相同。

  • output_file_path:性能分析结果输出路径。调试模式会使用 viztracer 对 Workflow 运行过程进行性能分析,并将结果保存为 HTML 文件,便于在浏览器中查看。

  • plugin_dir(可选):插件目录路径。如果你的 Workflow 或奖励函数等模块未内置于 Trinity-RFT,可通过该参数加载自定义模块。

调试过程中,配置文件中的 buffer.explorer_input.taskset 字段会被加载,用于初始化 Workflow 所需的任务数据集和实例。需注意,调试模式仅会读取数据集中的第一条数据进行测试。运行上述命令后,Workflow 的返回值会自动格式化并打印在终端,方便查看运行结果。

调试完成后,可在推理模型终端输入 Ctrl+C 以终止模型运行。