🧪 实验性功能:任务选择器#
备注
该模块目前处于 实验阶段,接口可能在后续版本中发生变化。 本文档描述了系统的功能及预期使用方式。
概述#
本系统支持在探索过程中,从多个数据集/任务集(称为 tasksets)中进行智能、自适应的任务采样。它包含两个核心组件:
Selector(选择器) —— 控制每个任务集中如何选择单个样本。TasksetScheduler(任务集调度器) —— 管理哪些任务集参与当前批次的训练,并协调它们的采样过程。
二者结合,支持以下高级训练策略:
课程学习(由易到难)
多任务交替/混合训练
基于难度的采样
根据模型表现动态调整数据选择
这些能力使你能够更高效地训练模型,聚焦于信息量大或具有挑战性的样本。
模块 1:Selector —— 可定制的数据选择机制#
Selector 决定从其对应的数据集(Taskset)中选择哪些任务(样本)。除了基本的顺序或随机访问策略外,它还支持基于反馈信号(如样本难度、模型置信度、奖励等)动态调整采样行为的自适应算法。
内置的选择器类型#
选择器类型 |
说明 |
|---|---|
|
按固定顺序返回样本(0, 1, ..., N)。 |
|
每个 epoch 开始时对数据集整体打乱一次,之后按顺序遍历。 |
|
在每个 batch 中无放回地随机采样,不同 batch 之间相互独立。 |
|
根据预定义特征(如损失值、长度)对样本排序,先提供简单样本,逐步过渡到困难样本。 |
|
使用概率建模动态选择接近目标难度水平的样本。 |
你也可以实现自己的自定义选择器,以支持自适应或课程式学习。
✅ 步骤 1:实现一个自定义选择器#
要创建新的选择器,需继承 BaseSelector 类,并实现以下方法:
必须实现的方法#
方法 |
功能说明 |
|---|---|
|
返回接下来要读取的样本索引列表。 |
|
使用反馈信息(如奖励、损失)更新内部状态,用于自适应调整。 |
|
序列化当前状态,用于保存检查点。 |
|
从保存的状态字典中恢复选择器状态。 |
示例:DifficultyBasedSelector#
该选择器聚焦于模型预测表现最接近目标值的样本(例如 90% 成功率),从而挑选出“难度适中”的任务。
class DifficultyBasedSelector(BaseSelector):
def __init__(self, data_source, config: TaskSelectorConfig) -> None:
super().__init__(data_source, config)
self.logger = get_logger("difficulty_based_selector")
# 使用两个输入特征(如正确性、不确定性)构建难度估计器
self.diff_estimator = self.build_diff_estimator(
data_source.dataset, config.feature_keys, config.kwargs
)
self.current_index = 0
self.seed = config.seed
# 配置参数
self.do_sample = config.kwargs.get("do_sample", False)
self.target_reward = config.kwargs.get("target_reward", 1.0)
self.tau = config.kwargs.get("tau", 1.0)
# ... 具体实现省略
def get_indices(self, batch_size, return_extra_info=False):
# 计算得分:越接近目标奖励得分越高
sampling_scores = self.get_scores()
sampling_scores = torch.from_numpy(sampling_scores)
if self.tau == 0:
# 贪心策略:选择得分最高的 top-k 样本
selected_indices = torch.topk(sampling_scores, batch_size).indices
else:
# 随机采样:通过带温度的 softmax 进行采样
sampling_logits = sampling_scores / self.tau
sampling_logits -= sampling_logits.max() # 数值稳定性处理
sampling_probabilities = torch.softmax(sampling_logits, dim=0)
rng = torch.Generator().manual_seed(self.seed + self.current_index)
selected_indices = torch.multinomial(
sampling_probabilities,
batch_size,
replacement=False,
generator=rng,
)
self.current_index += batch_size
if return_extra_info:
# 可选:返回调试信息
extra_info = {
"indices": selected_indices.tolist(),
"scores": sampling_scores[selected_indices].tolist(),
# ... 其他元数据
}
return selected_indices, extra_info
else:
return selected_indices
def update(self, indices: List[int], values: List[float]) -> None:
# 使用观测到的奖励更新难度模型
self.diff_estimator.update(indices, values)
def state_dict(self) -> Dict:
return {"current_index": self.current_index}
def load_state_dict(self, state_dict: Dict) -> None:
self.current_index = state_dict.get("current_index", 0)
🔁 定义完类后,请在
trinity/buffer/selector/__init__.py中的default_mapping中注册,以便在配置文件中通过名称引用。
SELECTORS = Registry(
"selectors",
default_mapping={
"difficulty_based": "trinity.buffer.selector.selector.DifficultyBasedSelector",
},
)
✅ 步骤 2:实现反馈操作器(Feedback Operator)#
对于像 DifficultyBasedSelector 这样的自适应选择器,你需要提供运行时反馈(例如任务奖励)。这通过一个 Experience Operator(经验操作器) 实现,它处理 rollout 数据并计算相关指标。
📚 更多关于自定义经验处理器的内容,请参见 Operator 开发指南。
操作器必须输出一个键为 trinity.common.constants.SELECTOR_METRIC 的指标,结构如下:
{
SELECTOR_METRIC: {
0: { # taskset_id
"indices": [10, 25, 43],
"values": [0.8, 0.6, 0.9] # 例如:平均奖励值
},
1: { ... }
}
}
示例:通过率计算器(Pass Rate Calculator)#
class PassRateCalculator(ExperienceOperator):
def __init__(self, **kwargs):
pass
def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
raw_metric = defaultdict(lambda: defaultdict(list))
for exp in exps:
task_index = exp.info["task_index"]
assert "taskset_id" in task_index and "index" in task_index
raw_metric[task_index["taskset_id"]][task_index["index"]].append(exp.reward)
metric = {}
for taskset_id, task_metrics in raw_metric.items():
indices = []
reward_means = []
for idx, rewards in task_metrics.items():
indices.append(idx)
reward_means.append(float(np.mean(rewards)))
metric[taskset_id] = {
"indices": indices,
"values": reward_means,
}
return exps, {SELECTOR_METRIC: metric}
该操作器计算每个任务的平均奖励,并将其传回对应的 Selector,用于更新难度估计。
✅ 步骤 3:更新配置文件#
完成选择器和操作器的实现后,需要在配置文件中注册它们。
将操作器加入处理流程#
data_processor:
experience_pipeline:
operators:
- name: pass_rate_calculator
为任务集配置你的选择器#
buffer:
explorer_input:
tasksets:
- name: my_taskset
storage_type: file
path: ./path/to/tasks
task_selector:
selector_type: difficulty_based
feature_keys: ["correct", "uncertainty"]
kwargs:
m: 16
lamb: 0.2
rho: 0.2
target_reward: 0.9
tau: 0.5
do_sample: true
💡 你可以定义多个任务集,每个都可以使用不同类型和配置的选择器。
模块 2:TasksetScheduler —— 多任务集协调调度#
TasksetScheduler 负责管理训练过程中不同任务集之间的交错方式。
主要特性#
支持同时加载多个任务集。
按数据集大小比例平衡采样权重。
每个 epoch 开始时打乱任务集的访问顺序。
支持课程式学习或多任务交替/混合训练。
完全可恢复断点:能精确从中断处继续训练。
与任意已注册的
Selector无缝集成。
工作原理#
在每一步训练中:
确定哪些任务集应参与当前 batch;
向各任务集的选择器请求具体的样本索引;
异步读取实际数据;
为每个任务打上
"taskset_id"标签,便于下游路由或分析。
每个 epoch 的步数由总样本数和 batch size 决定:
steps_per_epoch = total_samples // batch_size
每个 epoch 开始时,调度器会重新打乱任务集的访问顺序,以增加多样性。
总结#
通过这两个组件,你可以:
使用简单的策略(如随机或顺序采样);
利用自定义选择器设计自适应课程学习策略;
智能地融合多个数据集;
通过聚焦高价值样本提升训练效率。
将智能的 Selector 与灵活的 TasksetScheduler 结合,你将获得对模型所见内容及其出现时机的精细控制能力。