API 参考

API 参考

本节介绍 Twinkle 的组件 API 和定制选项。

核心组件

| 组件 | 模块 | 描述 |
| --- | --- | --- |
| Dataset | twinkle.dataset | 支持 ModelScope/HuggingFace 的数据加载 |
| DataLoader | twinkle.dataloader | 支持 device mesh 的分布式数据加载 |
| Model | twinkle.model | TransformersModel 和 MegatronModel 封装 |
| Sampler | twinkle.sampler | 基于 vLLM 的推理/RL 采样 |
| Loss | twinkle.loss | CrossEntropy、GRPO、GKD 损失函数 |
| Reward | twinkle.reward | RL 训练的奖励函数 |
| Advantage | twinkle.advantage | 优势估计(GRPO、GAE) |
| Metric | twinkle.metric | 训练指标收集 |
| Template | twinkle.template | 分词模板 |
| Preprocessor | twinkle.preprocessor | 数据 ETL 转换 |

初始化

import twinkle
from twinkle import DeviceMesh, DeviceGroup

# 定义设备组
device_groups = [
    DeviceGroup(name='model', ranks=4, device_type='cuda'),
    DeviceGroup(name='sampler', ranks=4, device_type='cuda'),
]

# 定义并行拓扑
device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)

# 初始化
twinkle.initialize(
    mode='ray',           # 'local', 'ray', 或 'http'
    nproc_per_node=8,
    groups=device_groups,
    global_device_mesh=device_mesh
)

Model API

from twinkle.model import TransformersModel
from peft import LoraConfig

model = TransformersModel(
    model_id='ms://Qwen/Qwen3.5-4B',
    remote_group='model',
    device_mesh=device_mesh
)

# 添加 LoRA 适配器
lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)

# 设置优化器和调度器
model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
model.set_lr_scheduler(scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5)

# 设置损失函数
model.set_loss('GRPOLoss', epsilon=0.2)

# 训练步骤
model.forward_backward(inputs=batch)
model.clip_grad_and_step()

# 保存检查点
model.save('checkpoint-name')

Sampler API

from twinkle.sampler import vLLMSampler
from twinkle.data_format import SamplingParams

sampler = vLLMSampler(
    model_id='ms://Qwen/Qwen3.5-4B',
    engine_args={'gpu_memory_utilization': 0.8, 'enable_lora': True},
    device_mesh=sampler_mesh,
    remote_group='sampler'
)

sampling_params = SamplingParams(max_tokens=1024, num_samples=4, logprobs=1)
responses = sampler.sample(prompts, sampling_params)

Dataset API

from twinkle.dataset import Dataset, DatasetMeta
from twinkle.dataloader import DataLoader

dataset = Dataset(dataset_meta=DatasetMeta(
    'ms://swift/self-cognition',
    data_slice=range(1000)
))
dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B')
dataset.map(SelfCognitionProcessor('Model', 'Author'))
dataset.encode()

dataloader = DataLoader(dataset=dataset, batch_size=8)

Client API

from twinkle_client import init_twinkle_client
from twinkle_client.model import MultiLoraTransformersModel

client = init_twinkle_client(base_url='http://server:8000', api_key='key')

model = MultiLoraTransformersModel(model_id='ms://Qwen/Qwen3.5-4B')
model.add_adapter_to_model('tenant_a', lora_config)
model.forward_backward(inputs=batch)
model.step()

更多详情请参阅 Cookbook 示例。

docs