Operator 开发指南#

步骤 0:Operator 模块基本概念#

Operator 模块负责处理由 Explorer 所生成的轨迹数据(我们称之为 Experience)。它原生支持来自 Data-Juicer 的数据处理功能,也允许开发者实现自己的算子。 通过自定义数据处理算子,开发者可以实现各种数据处理功能,如数据增强、过滤和转换。你甚至可以将优势值/回报值计算实现为 Operator,如 算法 部分所示。

  • DataJuicerOperator (trinity.buffer.operators.DataJuicerOperator):封装后的 Data-Juicer 算子,使用时只需在配置文件中标明想要使用的 Data-Juicer 算子列表即可。完整的 Data-Juicer 算子列表请见 此处

  • ExperienceOperator (trinity.buffer.operators.ExperienceOperator):用于 experience 数据处理的所有数据处理算子的基类。定义了所有数据处理算子应具备的接口和通用功能。每个算子处理一批 experience 数据,并返回处理后的数据及用于日志记录的指标。

  • ExperiencePipeline (trinity.buffer.pipelines.ExperiencePipeline):管理一系列数据处理算子的 experience 数据处理流水线。它从 Explorer 获取原始 experience,通过流水线中的每个算子处理,最后将最终处理过的 experience 写入 Trainer 的输入缓冲区。

备注

除了 ExperiencePipeline,Trinity-RFT 还提供 TaskPipeline 用于任务数据处理。 当前版本中,TaskPipeline 仅支持使用 Data-Juicer 算子。详情请参见 数据处理 部分。


开发者可通过以下步骤实现并使用自己的算子。

步骤 1:实现数据处理算子#

ExperienceOperator 接口仅包含一个 process 方法。ExperiencePipeline 将调用此方法,传入 Explorer 在一次探索步骤中生成的一组 Experienceprocess 方法应返回一个元组,包含处理后的 Experience 列表和用于日志记录的指标字典。

class ExperienceOperator(ABC):

    @abstractmethod
    def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        """Process a list of experiences and return a transformed list.

        Args:
            exps (List[Experience]): List of experiences to process, which contains
                all experiences generated by the Explorer in one explore step.
        Returns:
            Tuple[List[Experience], Dict]: A tuple containing the processed list of experiences and a dictionary of metrics.
        """

以下是一个简单数据处理算子的实现示例,该算子过滤掉奖励低于某一阈值的 experience:

from trinity.buffer.operators import ExperienceOperator
from trinity.common.experience import Experience


class RewardFilter(ExperienceOperator):

    def __init__(self, threshold: float = 0.0) -> None:
        self.threshold = threshold

    def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        filtered_exps = [exp for exp in exps if exp.reward >= self.threshold]
        metrics = {"filtered_count": len(exps) - len(filtered_exps)}
        return filtered_exps, metrics

实现后,你需要在 trinity/buffer/operators/__init__.py 中的 default_mapping 中注册此模块。注册后,该模块可在配置文件中使用注册名称进行配置。

EXPERIENCE_OPERATORS = Registry(
    "experience_operators",
    default_mapping={
        "reward_filter": "trinity.buffer.operators.filters.reward_filter.RewardFilter",
    },
)

步骤 2:使用此算子#

完成上述步骤后,你可以通过 YAML 配置文件使用新注册的算子。

# some other configs
data_processor:
  experience_pipeline:
    operators:
      - name: "reward_filter"
        args:
          threshold: 0.1
synchronizer:
  sync_method: nccl
  sync_style: dynamic_by_explorer
  sync_interval: 2
# some other configs

小技巧

RewardFilter 会减少 experience 数量,可能导致 Trainer 无法获得足够的 experience 来启动训练流程。为避免此问题,你可以使用 Trinity-RFT 提供的 动态同步 功能 (dynamic_by_explorer)。 上述设置意味着 Explorer 每运行 2 步就会与 Trainer 同步一次,且无论 Trainer 当前完成了多少步都会继续运行。这确保了只要 Explorer 在运行,Trainer 就总能获得足够的 experience 来启动训练步骤。