多轮 RL (OpenEnv)
·
1 分钟阅读时长
多轮 GRPO + 交互式环境 — Agent 通过 tool call 与环境交互,从 episode reward 中学习。
import twinkle
from twinkle import DeviceMesh, DeviceGroup
from twinkle.advantage import GRPOAdvantage
from twinkle.data_format import SamplingParams
from twinkle.sampler import vLLMSampler
from twinkle.model import TransformersModel
from twinkle_agentic.envs import OpenEnv, EnvTool
from twinkle_agentic.rollout.multi_turn import MultiTurnRollout
from twinkle_agentic.tools.tool_manager import ToolManager
# Multi-turn GRPO example: the agent interacts with an OpenEnv environment
# through tool calls and is trained on the resulting episode rewards.

MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
MODEL_GPUS, SAMPLER_GPUS = 4, 4
MAX_STEPS = 1000
# GRPO group size. The same value multiplies BATCH_SIZE to size the rollout
# batch and is passed as num_generations to GRPOAdvantage, so it is defined
# once here to keep the two from drifting apart.
NUM_GENERATIONS = 8

# NOTE(review): BATCH_SIZE, TOOL_SCHEMA, SYSTEM_PROMPT, prepare_trajectories
# and extract_rewards are used below but not defined in this snippet; they
# must be provided by the surrounding script.

# The first MODEL_GPUS ranks go to the 'model' group; the next SAMPLER_GPUS
# ranks go to the 'sampler' group.
device_groups = [
    DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
    DeviceGroup(
        name='sampler',
        ranks=list(range(MODEL_GPUS, MODEL_GPUS + SAMPLER_GPUS)),
        device_type='GPU',
    ),
]
# NOTE(review): nproc_per_node=8 matches MODEL_GPUS + SAMPLER_GPUS for the
# values above -- confirm and adjust if the GPU split changes.
twinkle.initialize(mode='ray', nproc_per_node=8, groups=device_groups)

model = TransformersModel(model_id=MODEL_ID, remote_group='model')
sampler = vLLMSampler(model_id=MODEL_ID, remote_group='sampler')
sampling_params = SamplingParams(max_tokens=2048)
rollout = MultiTurnRollout(
    sampler=sampler,
    sampling_params=sampling_params,
    max_turns=6,
)

for step in range(MAX_STEPS):
    # 1. Reset the environment and obtain the initial observation.
    # prepare_trajectories: creates the OpenEnv connections, resets the
    # environments, and builds the initial trajectory dicts.
    trajectories, tool_managers, env_tools = prepare_trajectories(
        n_trajectories=BATCH_SIZE * NUM_GENERATIONS,
        env_url='http://localhost:8000',
        tool_schema=TOOL_SCHEMA,
        system_prompt=SYSTEM_PROMPT,
    )

    # 2. Multi-turn rollout: the model generates tool calls and the
    # environment returns observations.
    all_trajectories = rollout(trajectories, tool_manager=tool_managers)

    # 3. Extract the episode rewards from the environment.
    # extract_rewards: reads the accumulated reward from each EnvTool instance.
    rewards = extract_rewards(env_tools)

    # 4. GRPO advantage -> forward_backward -> optimizer step.
    advantages = GRPOAdvantage()(
        rewards,
        num_generations=NUM_GENERATIONS,
        scale='group',
    )
    model.forward_backward(inputs=all_trajectories, advantages=advantages)
    model.clip_grad_and_step()