Asynchronous RFT

This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, the Qwen2.5-1.5B-Instruct model, and the GSM8K dataset.

Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes.

For this purpose, we prepare two main config files: explorer.yaml and trainer.yaml. The main difference between them is the mode field, which is set to explore in explorer.yaml and to train in trainer.yaml; the buffer and synchronizer sections are kept consistent so that both processes share the same experience buffer. The model weights of the explorer and trainer are synchronized once every sync_interval * batch_size tasks; with the settings below (sync_interval: 10, batch_size: 96), that is once every 960 tasks.

Suppose we have a node with 8 GPUs; we use 4 GPUs for the trainer and 4 GPUs for the explorer. The key settings in explorer.yaml are listed below:

project: <project_name>
name: <experiment_name>
mode: explore
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
  algorithm_type: grpo
  repeat_times: 8
model:
  model_path: /PATH/TO/MODEL/
cluster:
  node_num: 1
  gpu_per_node: 4
buffer:
  total_epochs: 1
  batch_size: 96
  explorer_input:
    taskset:
      name: gsm8k
      storage_type: file
      path: /PATH/TO/DATASET/
      split: train
      format:
        prompt_key: 'question'
        response_key: 'answer'
      rollout_args:
        temperature: 1.0
    default_workflow_type: 'math_workflow'
  trainer_input:
    experience_buffer:
      name: gsm8k_buffer
      storage_type: queue
      path: 'sqlite:///gsm8k.db'
explorer:
  eval_interval: 10
  runner_num: 32
  rollout_model:
    engine_type: vllm_async
    engine_num: 4
synchronizer:
  sync_method: 'checkpoint'
  sync_interval: 10
trainer:
  trainer_config_path: examples/async_gsm8k/verl_config.yaml

The key settings in trainer.yaml are listed below:

project: <project_name>
name: <experiment_name>
mode: train
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
  algorithm_type: grpo
  repeat_times: 8
model:
  model_path: /PATH/TO/MODEL/
cluster:
  node_num: 1
  gpu_per_node: 4
buffer:
  total_epochs: 1
  batch_size: 96
  explorer_input:
    taskset:
      name: gsm8k
      storage_type: file
      path: /PATH/TO/DATASET/
      format:
        prompt_key: 'question'
        response_key: 'answer'
      rollout_args:
        temperature: 1.0
    default_workflow_type: 'math_workflow'
  trainer_input:
    experience_buffer:
      name: gsm8k_buffer
      storage_type: queue
      path: 'sqlite:///gsm8k.db'
synchronizer:
  sync_method: 'checkpoint'
  sync_interval: 10
trainer:
  trainer_config_path: examples/async_gsm8k/verl_config.yaml
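
The two files above are intentionally kept close to each other; in particular, the trainer_input.experience_buffer and synchronizer sections match so that both processes use the same queue and sync schedule. To confirm exactly which settings differ in your checkout, you can simply diff the two files:

diff examples/async_gsm8k/explorer.yaml examples/async_gsm8k/trainer.yaml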

You may run this example with the following command:

bash examples/async_gsm8k/run.sh
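
The script starts the explorer and the trainer as two separate processes, each driven by its own config file. As a rough, minimal sketch only (it assumes the trinity run command-line entry point; consult examples/async_gsm8k/run.sh for the actual commands and environment setup):

#!/bin/bash
# Minimal sketch of an asynchronous launch; not the actual run.sh.
set -e

# Start the trainer in the background; it consumes experiences from the queue.
trinity run --config examples/async_gsm8k/trainer.yaml &
TRAINER_PID=$!

# Start the explorer; it generates rollouts and fills the experience buffer.
trinity run --config examples/async_gsm8k/explorer.yaml &
EXPLORER_PID=$!

# Wait for both processes to finish.
wait $TRAINER_PID $EXPLORER_PID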

The following plot shows the learning curve of GRPO in asynchronous mode.

This result should be regarded merely as a baseline, since GRPO is designed as an on-policy algorithm. We are continuing to investigate other RL algorithms (e.g., OPMD) in asynchronous mode.

(Figure: learning curve of GRPO in asynchronous mode on GSM8K.)