序列并行 (SP)
序列并行沿序列维度将长序列分割到多个 GPU 上,使训练能处理超出单卡显存的序列长度。Twinkle 实现了 Ulysses 风格的序列并行,并可选地支持派生环形注意力。
概览
| 概念 | 说明 |
|---|---|
| SequenceParallelConfig | SP 配置数据类 |
| SequenceParallelStrategy | 封装 SP 生命周期的策略类 |
| SequenceParallel | 核心实现,处理填充/分割/聚合 |
配置
from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallelConfig
config = SequenceParallelConfig(
enabled=True, # 启用序列并行
ulysses_size=None, # Ulysses SP 并行度(若为 None 则从 DeviceMesh 自动推导)
gather_logits=True, # 前向后聚合 logits 用于损失计算
)
配合 DeviceMesh 使用
在 DeviceMesh.from_sizes() 中设置 ulysses_size 即可激活 SP:
from twinkle.utils import DeviceMesh
# 8 卡:4 路 Ulysses SP × 2 路数据并行
device_mesh = DeviceMesh.from_sizes(
world_size=8,
dp_size=2,
ulysses_size=4,
)
工作原理
- 填充 — 输入序列被填充到可被 SP 并行度整除的长度
- 分割 — 填充后的输入沿序列维度均匀分配到各 SP rank
- 分布式注意力 — FlashAttention2 被 patch 为在注意力计算前后执行 Ulysses all-to-all 通信
- 聚合 — 前向传播后,logits 被聚合回完整序列长度用于损失计算
支持的注意力后端
| 后端 | 状态 |
|---|---|
| FlashAttention2 | 完全支持(包括打包/padding-free 序列) |
| SDPA | 支持(仅非打包批次) |
| 派生环形注意力 | 仅支持 FlashAttention2(rp_world_size > 1) |
Qwen3.5 线性注意力
SP 自动检测 Qwen3.5 GatedDeltaNet 线性注意力层,并应用 Qwen3_5GatedDeltaNetUlyssesPatch,确保混合注意力架构下序列并行的正确性。
MoE 辅助损失
对于 MoE 模型,SP 自动安装前向 hook,在计算辅助损失前跨 SP rank 聚合路由 logits,确保负载均衡信号的正确性。
关键约束
num_key_value_heads必须能被ulysses_size整除(Ulysses 模式),否则回退到环形注意力- 打包/padding-free 批次需要 FlashAttention2
- 派生环形注意力要求
batch_size == 1(打包格式) torch.distributed必须已初始化