Operator 开发指南#
步骤 0:Operator 模块基本概念#
Operator 模块负责处理由 Explorer 所生成的轨迹数据(我们称之为 Experience)。它原生支持来自 Data-Juicer 的数据处理功能,也允许开发者实现自己的算子。
通过自定义数据处理算子,开发者可以实现各种数据处理功能,如数据增强、过滤和转换。你甚至可以将优势值/回报值计算实现为 Operator,如 算法 部分所示。
DataJuicerOperator (
trinity.buffer.operators.DataJuicerOperator):封装后的 Data-Juicer 算子,使用时只需在配置文件中标明想要使用的 Data-Juicer 算子列表即可。完整的 Data-Juicer 算子列表请见 此处。ExperienceOperator (
trinity.buffer.operators.ExperienceOperator):用于 experience 数据处理的所有数据处理算子的基类。定义了所有数据处理算子应具备的接口和通用功能。每个算子处理一批 experience 数据,并返回处理后的数据及用于日志记录的指标。ExperiencePipeline (
trinity.buffer.pipelines.ExperiencePipeline):管理一系列数据处理算子的 experience 数据处理流水线。它从Explorer获取原始 experience,通过流水线中的每个算子处理,最后将最终处理过的 experience 写入Trainer的输入缓冲区。
备注
除了 ExperiencePipeline,Trinity-RFT 还提供 TaskPipeline 用于任务数据处理。
当前版本中,TaskPipeline 仅支持使用 Data-Juicer 算子。详情请参见 数据处理 部分。
开发者可通过以下步骤实现并使用自己的算子。
步骤 1:实现数据处理算子#
ExperienceOperator 接口仅包含一个 process 方法。ExperiencePipeline 将调用此方法,传入 Explorer 在一次探索步骤中生成的一组 Experience。process 方法应返回一个元组,包含处理后的 Experience 列表和用于日志记录的指标字典。
class ExperienceOperator(ABC):
@abstractmethod
def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
"""Process a list of experiences and return a transformed list.
Args:
exps (List[Experience]): List of experiences to process, which contains
all experiences generated by the Explorer in one explore step.
Returns:
Tuple[List[Experience], Dict]: A tuple containing the processed list of experiences and a dictionary of metrics.
"""
以下是一个简单数据处理算子的实现示例,该算子过滤掉奖励低于某一阈值的 experience:
from trinity.buffer.operators import EXPERIENCE_OPERATORS, ExperienceOperator
from trinity.common.experience import Experience
@EXPERIENCE_OPERATORS.register_module("reward_filter")
class RewardFilter(ExperienceOperator):
def __init__(self, threshold: float = 0.0) -> None:
self.threshold = threshold
def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
filtered_exps = [exp for exp in exps if exp.reward >= self.threshold]
metrics = {"filtered_count": len(exps) - len(filtered_exps)}
return filtered_exps, metrics
实现后,你需要通过 trinity.buffer.operators.EXPERIENCE_OPERATORS 注册此模块。注册后,该模块可在配置文件中使用注册名称进行配置。
步骤 2:使用此算子#
完成上述步骤后,你可以通过 YAML 配置文件使用新注册的算子。
# some other configs
data_processor:
experience_pipeline:
operators:
- name: "reward_filter"
args:
threshold: 0.1
synchronizer:
sync_method: nccl
sync_style: dynamic_by_explorer
sync_interval: 2
# some other configs
小技巧
RewardFilter 会减少 experience 数量,可能导致 Trainer 无法获得足够的 experience 来启动训练流程。为避免此问题,你可以使用 Trinity-RFT 提供的 动态同步 功能 (dynamic_by_explorer)。
上述设置意味着 Explorer 每运行 2 步就会与 Trainer 同步一次,且无论 Trainer 当前完成了多少步都会继续运行。这确保了只要 Explorer 在运行,Trainer 就总能获得足够的 experience 来启动训练步骤。