# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm
"""
import logging
import os
import warnings
import psutil
import torch
import torch.distributed
import verl.utils.torch_functional as verl_F
from codetiming import Timer
from omegaconf import DictConfig, open_dict
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FlatParameter
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import FSDP_PREFIX
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
get_fsdp_wrap_policy,
get_init_weight_context_manager,
init_fn,
load_fsdp_model_to_gpu,
load_fsdp_optimizer,
offload_fsdp_model_to_cpu,
offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
from trinity.common.constants import AlgorithmType, SyncMethod
from trinity.utils.distributed import init_process_group, is_ipv6_address
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))
def create_device_mesh(world_size, fsdp_size):
if fsdp_size < 0 or fsdp_size >= world_size:
device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
else:
device_mesh = init_device_mesh(
"cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]
)
return device_mesh
def get_sharding_strategy(device_mesh):
from torch.distributed.fsdp import ShardingStrategy
if device_mesh.ndim == 1:
sharding_strategy = ShardingStrategy.FULL_SHARD
elif device_mesh.ndim == 2:
sharding_strategy = ShardingStrategy.HYBRID_SHARD
else:
raise NotImplementedError(
f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2"
)
return sharding_strategy
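# A minimal usage sketch of the two helpers above (illustrative numbers only):
#
#   mesh = create_device_mesh(world_size=8, fsdp_size=2)    # 2-D mesh (4, 2) with dims ("ddp", "fsdp")
#   get_sharding_strategy(mesh)                              # -> ShardingStrategy.HYBRID_SHARD
#
#   mesh = create_device_mesh(world_size=8, fsdp_size=-1)    # 1-D mesh (8,) with dim ("fsdp",)
#   get_sharding_strategy(mesh)                              # -> ShardingStrategy.FULL_SHARD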
class ActorRolloutRefWorker(Worker):
"""
This worker can be instantiated as a standalone actor, a standalone rollout, a standalone reference policy,
or a hybrid engine, depending on config.rollout
"""
def __init__(self, config: DictConfig, role: str):
super().__init__()
self.config = config
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
# build device mesh for FSDP
world_size = torch.distributed.get_world_size()
# TODO(sgm): support FSDP hybrid shard for larger models
self.device_mesh = create_device_mesh(
world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size
)
# build device mesh for Ulysses Sequence Parallel
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.actor.get(
"ulysses_sequence_parallel_size", 1
)
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
"cuda",
mesh_shape=(dp, self.ulysses_sequence_parallel_size),
mesh_dim_names=["dp", "sp"],
)
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
self.role = role
assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]
self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
self._is_ref = self.role in ["ref", "actor_rollout_ref"]
self._is_offload_param = False
self._is_offload_optimizer = False
if self._is_actor:
self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False)
self._is_offload_optimizer = self.config.actor.fsdp_config.get(
"optimizer_offload", False
)
elif self._is_ref:
# TODO: it seems that manual offload is slower than FSDP offload
self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False)
# normalize config
if self._is_actor:
self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
self.config.actor.ppo_mini_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
assert (
self.config.actor.ppo_mini_batch_size > 0
), f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization"
# micro bsz
if self.config.actor.ppo_micro_batch_size is not None:
self.config.actor.ppo_micro_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
self.config.actor.ppo_micro_batch_size_per_gpu = (
self.config.actor.ppo_micro_batch_size
)
assert (
self.config.actor.ppo_mini_batch_size
% self.config.actor.ppo_micro_batch_size_per_gpu
== 0
), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
assert (
self.config.actor.ppo_mini_batch_size
// self.config.actor.ppo_micro_batch_size_per_gpu
> 0
), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
# normalize rollout config
if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
self.config.rollout.log_prob_micro_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
self.config.rollout.log_prob_micro_batch_size_per_gpu = (
self.config.rollout.log_prob_micro_batch_size
)
# normalize ref config
if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
self.config.ref.log_prob_micro_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
self.config.ref.log_prob_micro_batch_size_per_gpu = (
self.config.ref.log_prob_micro_batch_size
)
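# Normalization sketch for the batch-size arithmetic above (illustrative values,
# not from any shipped config): with ppo_mini_batch_size=256, rollout.n=4,
# world_size=8 and ulysses_sequence_parallel_size=2, the data-parallel size is
# device_mesh.size() // sp = 8 // 2 = 4, so ppo_mini_batch_size becomes
# 256 * 4 // 4 = 256 per DP rank; ppo_micro_batch_size is divided the same way
# and then copied into ppo_micro_batch_size_per_gpu.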
def _build_model_optimizer(
self,
model_path,
fsdp_config,
optim_config,
override_model_config,
use_remove_padding=False,
enable_gradient_checkpointing=False,
trust_remote_code=False,
use_liger=False,
role="actor",
):
from torch import optim
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForVision2Seq,
)
from verl.utils.model import (
get_generation_config,
print_model_size,
update_model_config,
)
from verl.utils.torch_dtypes import PrecisionType
assert role in ["actor", "ref"]
log_gpu_memory_usage("Before init from HF AutoModel", logger=logger)
local_path = copy_to_local(model_path)
# note that we have to create the model in fp32; otherwise, the optimizer would be in bf16, which is incorrect
# TODO(zhangchi.usc1992): 1. support creating from a randomly initialized model. 2. support initializing with FSDP directly
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)
torch_dtype = fsdp_config.get("model_dtype", None)
if torch_dtype is None:
torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
else:
torch_dtype = PrecisionType.to_dtype(torch_dtype)
# override model kwargs
actor_model_config = AutoConfig.from_pretrained(
local_path, trust_remote_code=trust_remote_code
)
self.generation_config = get_generation_config(
local_path, trust_remote_code=trust_remote_code
)
override_config_kwargs = {
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
if self.rank == 0:
print(f"Model config after override: {actor_model_config}")
# NOTE(fix me): tie_word_embeddings causes meta-tensor init to hang
init_context = get_init_weight_context_manager(
use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh
)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys():
actor_module_class = AutoModelForVision2Seq
else:
actor_module_class = AutoModelForCausalLM
actor_module = actor_module_class.from_pretrained(
pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=actor_model_config,
attn_implementation="flash_attention_2",
trust_remote_code=trust_remote_code,
)
if use_remove_padding or self.ulysses_sequence_parallel_size > 1:
from verl.models.transformers.monkey_patch import apply_monkey_patch
apply_monkey_patch(
model=actor_module, ulysses_sp_size=self.ulysses_sequence_parallel_size
)
# Apply Liger kernel to the model if use_liger is set to True
if use_liger:
from liger_kernel.transformers.monkey_patch import (
_apply_liger_kernel_to_instance,
)
_apply_liger_kernel_to_instance(model=actor_module)
# some parameters may not be in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
actor_module.to(torch_dtype)
if enable_gradient_checkpointing:
actor_module.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
torch.distributed.barrier()
if self.rank == 0:
print_model_size(actor_module)
log_gpu_memory_usage("After init from HF AutoModel", logger=logger)
# We wrap FSDP for rollout as well
mixed_precision_config = fsdp_config.get("mixed_precision", None)
if mixed_precision_config is not None:
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
reduce_dtype = PrecisionType.to_dtype(
mixed_precision_config.get("reduce_dtype", "fp32")
)
buffer_dtype = PrecisionType.to_dtype(
mixed_precision_config.get("buffer_dtype", "fp32")
)
else:
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
buffer_dtype = torch.float32
mixed_precision = MixedPrecision(
param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype
)
auto_wrap_policy = get_fsdp_wrap_policy(
module=actor_module, config=fsdp_config.get("wrap_policy", None)
)
if self._is_rollout and self.config.rollout.name == "hf":
# TODO(zhangchi.usc1992, shengguangming) fix me. Currently, auto_wrap_policy causes HFRollout to hang with Gemma
auto_wrap_policy = None
print(f"wrap_policy: {auto_wrap_policy}")
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
# TODO: add transformer policy
# We force reference policy to use CPUOffload to save memory.
# We force CPUOffload off for the actor because it produces incorrect results when combined with gradient accumulation
cpu_offload = None if role == "actor" else CPUOffload(offload_params=True)
actor_module_fsdp = FSDP(
actor_module,
cpu_offload=cpu_offload,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy, # zero3
mixed_precision=mixed_precision,
sync_module_states=True,
device_mesh=self.device_mesh,
forward_prefetch=False,
)
log_gpu_memory_usage("After Actor FSDP init", logger=logger)
# TODO: add more optimizer args into config
if role == "actor" and optim_config is not None:
beta1 = optim_config.get("beta1", 0.9)
beta2 = optim_config.get("beta2", 0.999)
actor_optimizer = optim.AdamW(
actor_module_fsdp.parameters(),
lr=optim_config.lr,
betas=(beta1, beta2),
weight_decay=optim_config.get("weight_decay", 1e-2),
)
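# Warm-up resolution below (illustrative numbers): an explicit, non-negative
# lr_warmup_steps takes precedence; otherwise lr_warmup_steps_ratio is used,
# e.g. total_training_steps=1000 with lr_warmup_steps_ratio=0.1 gives
# num_warmup_steps = int(0.1 * 1000) = 100.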
total_steps = optim_config.get("total_training_steps", 0)
num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
if num_warmup_steps < 0:
num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
if optim_config.warmup_style == "constant":
from verl.utils.torch_functional import (
get_constant_schedule_with_warmup,
)
actor_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps
)
elif optim_config.warmup_style == "cosine":
from verl.utils.torch_functional import get_cosine_schedule_with_warmup
assert (
total_steps > 0
), "Cosine scheduler of actor requires total_training_steps > 0"
actor_lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=actor_optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps,
min_lr_ratio=optim_config.min_lr_ratio,
)
else:
raise NotImplementedError(
f"Lr scheduler style {optim_config.warmup_style} is not supported"
)
else:
actor_optimizer = None
actor_lr_scheduler = None
log_gpu_memory_usage("After actor optimizer init", logger=logger)
return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
def _build_rollout(self):
from torch.distributed.device_mesh import init_device_mesh
# TODO(sgm): support FSDP hybrid shard for larger models
infer_tp = self.config.rollout.tensor_model_parallel_size
dp = self.world_size // infer_tp
assert (
self.world_size % infer_tp == 0
), f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
rollout_device_mesh = init_device_mesh(
"cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
)
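# e.g. (illustrative): world_size=8 with tensor_model_parallel_size=2 yields a
# (4, 2) rollout mesh with dims ("dp", "infer_tp"); the assert above guarantees
# the division is exact.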
if self.config.rollout.name == "hf":
from verl.workers.rollout import HFRollout
from verl.workers.sharding_manager import BaseShardingManager
rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
rollout_sharding_manager = BaseShardingManager()
# TODO: a sharding manager that does nothing?
elif self.config.rollout.name == "vllm":
if self.config.rollout.use_fire_sampling:
from verl.workers.rollout.vllm_rollout import (
FIREvLLMRollout as vLLMRollout,
)
from verl.workers.rollout.vllm_rollout import vllm_mode
else:
from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode
from verl.workers.sharding_manager import FSDPVLLMShardingManager
log_gpu_memory_usage("Before building vllm rollout", logger=None)
local_path = copy_to_local(self.config.model.path)
if vllm_mode == "customized":
rollout = vLLMRollout(
actor_module=self.actor_module_fsdp,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
)
elif vllm_mode == "spmd":
rollout = vLLMRollout(
model_path=local_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
device_mesh=rollout_device_mesh,
)
else:
raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")
log_gpu_memory_usage("After building vllm rollout", logger=None)
if torch.distributed.get_world_size() == 1:
self.config.rollout.load_format = "dummy_hf"
rollout_sharding_manager = FSDPVLLMShardingManager(
module=self.actor_module_fsdp,
inference_engine=rollout.inference_engine,
model_config=self.actor_model_config,
full_params="hf" in self.config.rollout.load_format,
device_mesh=rollout_device_mesh,
)
log_gpu_memory_usage("After building sharding manager", logger=None)
return rollout, rollout_sharding_manager
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
from .dp_actor import DataParallelPPOActor
# This is used to import external_lib into the HuggingFace system
import_external_libs(self.config.model.get("external_lib", None))
from omegaconf import OmegaConf
override_model_config = OmegaConf.to_container(
self.config.model.get("override_config", OmegaConf.create())
)
use_remove_padding = self.config.model.get("use_remove_padding", False)
if self._is_actor or self._is_rollout:
# we need the model for actor and rollout
if self._is_actor:
optim_config = self.config.actor.optim
fsdp_config = self.config.actor.fsdp_config
else:
optim_config = None
fsdp_config = OmegaConf.create()
(
self.actor_module_fsdp,
self.actor_optimizer,
self.actor_lr_scheduler,
self.actor_model_config,
) = self._build_model_optimizer(
model_path=self.config.model.path,
fsdp_config=fsdp_config,
optim_config=optim_config,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
enable_gradient_checkpointing=self.config.model.get(
"enable_gradient_checkpointing", False
),
trust_remote_code=self.config.model.get("trust_remote_code", False),
use_liger=self.config.model.get("use_liger", False),
role="actor",
)
# get the original unwrapped module
self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)
# load from checkpoint
if self._is_actor:
OmegaConf.set_struct(self.config.actor, True)
with open_dict(self.config.actor):
self.config.actor.use_remove_padding = use_remove_padding
self.actor = DataParallelPPOActor(
config=self.config.actor,
actor_module=self.actor_module_fsdp,
actor_optimizer=self.actor_optimizer,
)
if self._is_rollout:
self.rollout, self.rollout_sharding_manager = self._build_rollout()
if self._is_ref:
self.ref_module_fsdp = self._build_model_optimizer(
model_path=self.config.model.path,
fsdp_config=self.config.ref.fsdp_config,
optim_config=None,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
trust_remote_code=self.config.model.get("trust_remote_code", False),
use_liger=self.config.model.get("use_liger", False),
role="ref",
)[0]
OmegaConf.set_struct(self.config.ref, True)
with open_dict(self.config.ref):
self.config.ref.use_remove_padding = use_remove_padding
self.ref_policy = DataParallelPPOActor(
config=self.config.ref, actor_module=self.ref_module_fsdp
)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.ref_module_fsdp,
optimizer=None,
lr_scheduler=None,
processing_class=self.processor if self.processor is not None else self.tokenizer,
)
if self._is_actor:
self.flops_counter = FlopsCounter(self.actor_model_config)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.actor_module_fsdp,
optimizer=self.actor.actor_optimizer,
lr_scheduler=self.actor_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_contents=self.config.actor.checkpoint.contents,
)
torch.cuda.empty_cache()
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def setup_weight_sync_group(self):
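# Sketch of the weight-sync handshake implemented below (the explorer-side
# counterpart lives outside this file):
#   1. every rank walks its FSDP modules and records (name, dtype, shape) for
#      each real (non-flat) parameter into self.state_dict_meta;
#   2. rank 0 asks the Ray actor named "explorer" to join a fresh NCCL group
#      "rollout_weight_sync" of size explorer_world_size + 1, with this trainer
#      process as rank 0;
#   3. sync_weight() later broadcasts full parameters from rank 0 over that group.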
if (
hasattr(self.config, "synchronizer")
and getattr(self.config.synchronizer, "sync_method", None) == SyncMethod.NCCL
):
model = self.actor_module_fsdp
self.named_modules = []
self.state_dict_meta = []
for name, module in model.named_modules():
if isinstance(module, FSDP):
self.named_modules.append((name, module))
for name_prefix, module in self.named_modules:
with FSDP.summon_full_params(module, recurse=False):
for name, param in module.named_parameters():
if isinstance(param, FlatParameter):
continue
realname = (
name_prefix[len(FSDP_PREFIX) :] + "." + name if name_prefix else name
)
self.state_dict_meta.append((realname, param.dtype, param.shape))
param = None
torch.cuda.empty_cache()
if torch.distributed.get_rank() == 0:
import ray
master_address, master_port = self.get_availale_master_addr_port()
world_size = self.config.synchronizer.explorer_world_size + 1
print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).")
explorer = ray.get_actor("explorer")
group_name = "rollout_weight_sync"
setup_ref = explorer.setup_weight_sync_group.remote(
master_address, master_port, self.state_dict_meta
)
if is_ipv6_address(master_address):
# using tcp://ipv6:port will lead to ValueError
init_method = f"tcp://[{master_address}]:{master_port}"
else:
init_method = f"tcp://{master_address}:{master_port}"
timeout = self.config.synchronizer.sync_timeout
self._model_update_group = init_process_group(
backend="nccl",
init_method=init_method,
timeout=timeout,
world_size=world_size,
rank=0,
group_name=group_name,
)
ray.get(setup_ref)
torch.distributed.barrier()
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def sync_weight(self):
for name_prefix, module in self.named_modules:
with FSDP.summon_full_params(module, recurse=False):
if torch.distributed.get_rank() == 0:
for name, param in module.named_parameters():
if isinstance(param, FlatParameter):
continue
torch.distributed.broadcast(param, 0, group=self._model_update_group)
param = None
torch.cuda.empty_cache()
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def set_mode(self, algo_type: AlgorithmType = AlgorithmType.PPO):
self.actor.set_mode(algo_type)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
# Support all hardware
data = data.to(torch.cuda.current_device())
assert self._is_actor
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
if self._is_offload_optimizer:
load_fsdp_optimizer(
optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()
)
log_gpu_memory_usage("Before update policy", logger=logger)
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
# perform training
with Timer(name="update_policy", logger=None) as timer:
metrics = self.actor.update_policy(data=data)
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(
global_num_tokens, delta_time
)
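# Model FLOPs utilization: estimated FLOPs for one pass, scaled by ppo_epochs,
# divided by the promised FLOPs aggregated across all world_size ranks.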
metrics["perf/mfu/actor"] = (
estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
)
metrics["perf/max_memory_allocated_gb"] = torch.cuda.max_memory_allocated() / (
1024**3
)
metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3)
metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)
self.actor_lr_scheduler.step()
lr = self.actor_lr_scheduler.get_last_lr()[0]
metrics["actor/lr"] = lr
log_gpu_memory_usage("After update policy", logger=logger)
# TODO: here, we should return all metrics
output = DataProto(meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
# Support all hardware
prompts = prompts.to(torch.cuda.current_device())
assert self._is_rollout
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
meta_info = {
"eos_token_id": self.generation_config.eos_token_id
if self.generation_config is not None
else self.tokenizer.eos_token_id,
"pad_token_id": self.generation_config.pad_token_id
if self.generation_config is not None
else self.tokenizer.pad_token_id,
}
prompts.meta_info.update(meta_info)
with self.rollout_sharding_manager:
# after parameters sync with rollout, offload actor model to CPU
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)
prompts = self.rollout_sharding_manager.preprocess_data(prompts)
output = self.rollout.generate_sequences(prompts=prompts)
log_gpu_memory_usage("After rollout generation", logger=logger)
output = self.rollout_sharding_manager.postprocess_data(output)
output = output.to("cpu")
# clear kv cache
torch.cuda.empty_cache()
log_gpu_memory_usage("After recompute log prob", logger=logger)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_prob(self, data: DataProto):
assert self._is_actor
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
# Support all hardware
data = data.to(torch.cuda.current_device())
# we should always recompute old_log_probs when using the hybrid engine
data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu
data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz
data.meta_info["temperature"] = self.config.rollout.temperature
# perform recompute log_prob
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output = self.actor.compute_log_prob(data=data)
output = DataProto.from_dict(
tensors={"old_log_probs": output},
meta_info={"temperature": self.config.rollout.temperature},
)
output = self.ulysses_sharding_manager.postprocess_data(output)
output = output.to("cpu")
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1:
self.actor.actor_module._handle.reshard(True)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
log_gpu_memory_usage("After compute_log_prob", logger=logger)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
assert self._is_ref
# Support all hardware
data = data.to(torch.cuda.current_device())
micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
data.meta_info["micro_batch_size"] = micro_batch_size
data.meta_info["temperature"] = self.config.rollout.temperature
data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output = self.ref_policy.compute_log_prob(data=data)
output = DataProto.from_dict(tensors={"ref_log_prob": output})
output = self.ulysses_sharding_manager.postprocess_data(output)
output = output.to("cpu")
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1:
self.ref_policy.actor_module._handle.reshard(True)
torch.cuda.empty_cache()
return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
# only support save and load ckpt for actor
assert self._is_actor
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
self.checkpoint_manager.save_checkpoint(
local_path=local_path,
hdfs_path=hdfs_path,
global_step=global_step,
max_ckpt_to_keep=max_ckpt_to_keep,
)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
if self._is_actor and self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
self.checkpoint_manager.load_checkpoint(
local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
)
if self._is_actor and self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
if self._is_actor and self._is_offload_optimizer:
offload_fsdp_optimizer(self.actor_optimizer)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def clear_optimizer_state(self):
print("Clear actor optimizer state")
if self._is_offload_optimizer:
load_fsdp_optimizer(
optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()
)
self.actor_optimizer.state.clear()
self.actor_optimizer.zero_grad()
if self._is_offload_optimizer:
offload_fsdp_optimizer(self.actor_optimizer)
class CriticWorker(Worker):
def __init__(self, config):
super().__init__()
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
self.config = config
# build device mesh for Ulysses Sequence Parallel
world_size = torch.distributed.get_world_size()
from torch.distributed.device_mesh import init_device_mesh
fsdp_size = self.config.model.fsdp_config.fsdp_size
self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
"cuda",
mesh_shape=(dp, self.ulysses_sequence_parallel_size),
mesh_dim_names=["dp", "sp"],
)
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
# set FSDP offload params
self._is_offload_param = self.config.model.fsdp_config.param_offload
self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload
# normalize config
self.config.ppo_mini_batch_size *= self.config.rollout_n
self.config.ppo_mini_batch_size //= (
torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
)
if self.config.ppo_micro_batch_size is not None:
self.config.ppo_micro_batch_size //= (
torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
)
self.config.forward_micro_batch_size //= (
torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
)
self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size
assert (
self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
assert (
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0
), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
def _build_critic_model_optimizer(self, config):
# the following line is necessary
from torch import optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from verl.utils.model import print_model_size
from verl.utils.torch_dtypes import PrecisionType
local_path = copy_to_local(config.model.path)
# note that the actor's and critic's tokenizers may differ, so we override the tokenizer info with the actor's
# the critic may use a randomly initialized model of any architecture; it need not match the actor
tokenizer_path = copy_to_local(config.model.tokenizer_path)
self.tokenizer = hf_tokenizer(
tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)
)
self.processor = hf_processor(
tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)
)
from omegaconf import OmegaConf
override_config = OmegaConf.to_container(
self.config.model.get("override_config", OmegaConf.create())
)
override_config_kwargs = {
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_config)
if self.rank == 0:
print(f"Critic overriding config {override_config_kwargs}")
torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
torch_dtype = PrecisionType.to_dtype(torch_dtype)
from transformers import AutoConfig, AutoModelForTokenClassification
trust_remote_code = False
critic_model_config = AutoConfig.from_pretrained(
local_path, trust_remote_code=trust_remote_code
)
critic_model_config.num_labels = 1
init_context = get_init_weight_context_manager(
use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh
)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
setattr(critic_model_config, "classifier_dropout", 0.0)
setattr(critic_model_config, "hidden_dropout", "0")
critic_module = AutoModelForTokenClassification.from_pretrained(
pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=critic_model_config,
attn_implementation="flash_attention_2",
trust_remote_code=trust_remote_code,
)
use_remove_padding = config.model.get("use_remove_padding", False)
if use_remove_padding or self.ulysses_sequence_parallel_size > 1:
from verl.models.transformers.monkey_patch import apply_monkey_patch
apply_monkey_patch(
model=critic_module, ulysses_sp_size=self.ulysses_sequence_parallel_size
)
# some parameters may not in torch_dtype
critic_module.to(torch_dtype)
if config.model.get("enable_gradient_checkpointing", False):
critic_module.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
if self.rank == 0:
print_model_size(critic_module)
self.critic_model_config = critic_model_config
fsdp_config = self.config.model.fsdp_config
mixed_precision_config = fsdp_config.get("mixed_precision", None)
if mixed_precision_config is not None:
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
reduce_dtype = PrecisionType.to_dtype(
mixed_precision_config.get("reduce_dtype", "fp32")
)
buffer_dtype = PrecisionType.to_dtype(
mixed_precision_config.get("buffer_dtype", "fp32")
)
else:
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
buffer_dtype = torch.float32
mixed_precision = MixedPrecision(
param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype
)
auto_wrap_policy = get_fsdp_wrap_policy(
module=critic_module, config=self.config.model.fsdp_config.wrap_policy
)
log_gpu_memory_usage("Before critic FSDP", logger=None)
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
# Note: We force CPUOffload off for the critic because it produces incorrect results when combined with gradient accumulation
critic_module = FSDP(
critic_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
sync_module_states=True,
forward_prefetch=False,
device_mesh=self.device_mesh,
cpu_offload=None,
)
log_gpu_memory_usage("After critic FSDP", logger=None)
beta1 = config.optim.get("beta1", 0.9)
beta2 = config.optim.get("beta2", 0.999)
critic_optimizer = optim.AdamW(
critic_module.parameters(),
lr=config.optim.lr,
betas=(beta1, beta2),
weight_decay=config.optim.get("weight_decay", 1e-2),
)
total_steps = config.optim.get("total_training_steps", 0)
num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
if num_warmup_steps < 0:
num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
if config.optim.warmup_style == "constant":
from verl.utils.torch_functional import get_constant_schedule_with_warmup
critic_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps
)
elif config.optim.warmup_style == "cosine":
from verl.utils.torch_functional import get_cosine_schedule_with_warmup
assert total_steps > 0, "Cosine scheduler of critic requires total_training_steps > 0"
critic_lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=critic_optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps,
min_lr_ratio=config.optim.min_lr_ratio,
)
else:
raise NotImplementedError(
f"Lr scheduler style {config.optim.warmup_style} is not supported"
)
return critic_module, critic_optimizer, critic_lr_scheduler
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
# This is used to import external_lib into the HuggingFace system
import_external_libs(self.config.model.get("external_lib", None))
from verl.workers.critic import DataParallelPPOCritic
(
self.critic_module,
self.critic_optimizer,
self.critic_lr_scheduler,
) = self._build_critic_model_optimizer(self.config)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.critic_optimizer)
self.critic = DataParallelPPOCritic(
config=self.config,
critic_module=self.critic_module,
critic_optimizer=self.critic_optimizer,
)
self.flops_counter = FlopsCounter(self.critic_model_config)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.critic_module,
optimizer=self.critic_optimizer,
lr_scheduler=self.critic_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_contents=self.config.checkpoint.contents,
)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_values(self, data: DataProto):
# Support all hardware
data = data.to(torch.cuda.current_device())
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
micro_batch_size = self.config.forward_micro_batch_size_per_gpu
data.meta_info["micro_batch_size"] = micro_batch_size
data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz
# perform forward computation
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
values = self.critic.compute_values(data=data)
output = DataProto.from_dict(tensors={"values": values})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_critic(self, data: DataProto):
# Support all hardware
data = data.to(torch.cuda.current_device())
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
if self._is_offload_optimizer:
load_fsdp_optimizer(
optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()
)
# perform forward computation
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
with Timer(name="update_critic", logger=None) as timer:
metrics = self.critic.update_critic(data=data)
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(
global_num_tokens, delta_time
)
metrics["perf/mfu/critic"] = (
estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
)
self.critic_lr_scheduler.step()
lr = self.critic_lr_scheduler.get_last_lr()[0]
metrics["critic/lr"] = lr
output = DataProto(batch=None, meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.critic_optimizer)
output = output.to("cpu")
return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
self.checkpoint_manager.save_checkpoint(
local_path=local_path,
hdfs_path=hdfs_path,
global_step=global_step,
max_ckpt_to_keep=max_ckpt_to_keep,
)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
self.checkpoint_manager.load_checkpoint(
local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(self.critic_optimizer)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def clear_optimizer_state(self):
print("Clear critic optimizer state")
if self._is_offload_optimizer:
load_fsdp_optimizer(
optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()
)
self.critic_optimizer.state.clear()
self.critic_optimizer.zero_grad()
if self._is_offload_optimizer:
offload_fsdp_optimizer(self.critic_optimizer)
# TODO(sgm): we may need to extract it to dp_reward_model.py
class RewardModelWorker(Worker):
"""
Note that we only implement reward models that are subclasses of AutoModelForTokenClassification.
"""
def __init__(self, config):
super().__init__()
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
self.config = config
# build device mesh for Ulysses Sequence Parallel
world_size = torch.distributed.get_world_size()
from torch.distributed.device_mesh import init_device_mesh
fsdp_size = self.config.model.fsdp_config.fsdp_size
self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
"cuda",
mesh_shape=(dp, self.ulysses_sequence_parallel_size),
mesh_dim_names=["dp", "sp"],
)
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
self.use_remove_padding = self.config.model.get("use_remove_padding", False)
# normalize config
if self.config.micro_batch_size is not None:
self.config.micro_batch_size //= torch.distributed.get_world_size()
self.config.micro_batch_size_per_gpu = self.config.micro_batch_size
def _build_model(self, config):
# the following line is necessary
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig, AutoModelForTokenClassification
# download the checkpoint from hdfs
local_path = copy_to_local(config.model.path)
if self.config.model.input_tokenizer is None:
self._do_switch_chat_template = False
else:
self._do_switch_chat_template = True
input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer)
self.input_tokenizer = hf_tokenizer(
input_tokenizer_local_path,
trust_remote_code=config.model.get("trust_remote_code", False),
)
self.tokenizer = hf_tokenizer(
local_path, trust_remote_code=config.model.get("trust_remote_code", False)
)
trust_remote_code = config.model.get("trust_remote_code", False)
model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
model_config.num_labels = 1
# the reward model is inference-only (no optimizer), so it is created directly in bf16 below
init_context = get_init_weight_context_manager(
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh
)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
setattr(model_config, "classifier_dropout", 0.0)
reward_module = AutoModelForTokenClassification.from_pretrained(
pretrained_model_name_or_path=local_path,
config=model_config,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
trust_remote_code=trust_remote_code,
)
if (
config.model.get("use_remove_padding", False)
or self.ulysses_sequence_parallel_size > 1
):
from verl.models.transformers.monkey_patch import apply_monkey_patch
apply_monkey_patch(
model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size
)
reward_module.to(torch.bfloat16)
auto_wrap_policy = get_fsdp_wrap_policy(
module=reward_module, config=self.config.model.fsdp_config
)
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
reward_module = FSDP(
reward_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy, # zero3
sync_module_states=True,
cpu_offload=CPUOffload(offload_params=True),
forward_prefetch=False,
device_mesh=self.device_mesh,
)
return reward_module
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
# This is used to import external_lib into the HuggingFace system
import_external_libs(self.config.model.get("external_lib", None))
self.reward_module = self._build_model(config=self.config)
def _forward_micro_batch(self, micro_batch):
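# Two paths below: with use_remove_padding the batch is flattened to a single
# (1, total_nnz) sequence (and optionally sliced for Ulysses SP) before the
# forward pass, then padded back; otherwise the padded batch is fed directly.
# In both cases the scalar reward is read at each sample's last valid token.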
from flash_attn.bert_padding import (
index_first_axis,
pad_input,
rearrange,
unpad_input,
)
from verl.utils.ulysses import (
gather_outpus_and_unpad,
ulysses_pad_and_slice_inputs,
)
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
input_ids = micro_batch["input_ids"]
batch_size, seqlen = input_ids.shape
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
if self.use_remove_padding:
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad position_ids to keep them aligned for rotary embeddings
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
# pad and slice the inputs if sp > 1
if self.ulysses_sequence_parallel_size > 1:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad,
position_ids_rmpad,
sp_size=self.ulysses_sequence_parallel_size,
)
# only pass input_ids and position_ids to enable flash_attn_varlen
output = self.reward_module(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
use_cache=False,
) # prevent the model from thinking we are generating
reward_rmpad = output.logits
reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz)
# gather output if sp > 1
if self.ulysses_sequence_parallel_size > 1:
reward_rmpad = gather_outpus_and_unpad(
reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
)
# pad it back
rm_score = pad_input(
reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen
).squeeze(-1)
else:
output = self.reward_module(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False,
)
rm_score = output.logits # (batch_size, seq_len, 1)
rm_score = rm_score.squeeze(-1)
# extract the result of the last valid token
eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,)
rm_score = rm_score[torch.arange(batch_size), eos_mask_idx]
return rm_score
def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):
batch_size = data.batch.batch_size[0]
# expand as token_level_reward
attention_mask = data.batch["attention_mask"]
position_ids = data.batch["position_ids"]
response_length = data.batch["responses"].shape[-1]
eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,)
token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen)
token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores
# select the response part
token_level_scores = token_level_scores[:, -response_length:]
return token_level_scores
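# Shape sketch (illustrative): for batch_size=2, seqlen=8 and a 4-token
# response region, token_level_scores starts as zeros of shape (2, 8), each
# sample's scalar score is written at its last valid position (argmax of
# position_ids * attention_mask), and the returned slice has shape (2, 4).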
def _switch_chat_template(self, data: DataProto):
src_max_length = data.batch["attention_mask"].shape[-1]
src_tokenizer = self.input_tokenizer
target_tokenizer = self.tokenizer
rm_input_ids = []
rm_attention_mask = []
for i in range(data.batch.batch_size[0]):
# extract raw prompt
chat: list = data.non_tensor_batch["raw_prompt"][i].tolist()
# extract response
response_ids = data.batch["responses"][i]
response_length = response_ids.shape[-1]
valid_response_length = data.batch["attention_mask"][i][-response_length:].sum()
valid_response_ids = response_ids[:valid_response_length]
# decode
response = src_tokenizer.decode(valid_response_ids)
# remove the eos token
response = response.replace(src_tokenizer.eos_token, "")
chat.append({"role": "assistant", "content": response})
prompt_with_chat_template = target_tokenizer.apply_chat_template(
chat, add_generation_prompt=False, tokenize=False
)
if self.rank == 0 and i == 0:
# for debugging purposes
print(f"Switch template. chat: {prompt_with_chat_template}")
# the maximum length is actually determined by the reward model itself
max_length = self.config.get("max_length", src_max_length)
if max_length is None:
max_length = src_max_length
input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
prompt=prompt_with_chat_template,
tokenizer=target_tokenizer,
max_length=max_length,
pad_token_id=target_tokenizer.pad_token_id,
left_pad=False, # right padding
truncation=self.config.get("truncation", "right"),
) # truncate from the right
rm_input_ids.append(input_ids)
rm_attention_mask.append(attention_mask)
rm_input_ids = torch.cat(rm_input_ids, dim=0)
rm_attention_mask = torch.cat(rm_attention_mask, dim=0)
rm_position_ids = compute_position_id_with_mask(rm_attention_mask)
rm_inputs = {
"input_ids": rm_input_ids,
"attention_mask": rm_attention_mask,
"position_ids": rm_position_ids,
}
return DataProto.from_dict(rm_inputs)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_rm_score(self, data: DataProto):
import itertools
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
# Support all hardware
data = data.to(torch.cuda.current_device())
if self._do_switch_chat_template:
rm_data = self._switch_chat_template(data)
# Support all hardware
rm_data.batch = rm_data.batch.to(torch.cuda.current_device())
# perform forward computation
with self.ulysses_sharding_manager:
rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data)
data = self.ulysses_sharding_manager.preprocess_data(data=data)
use_dynamic_bsz = self.config.use_dynamic_bsz
if use_dynamic_bsz:
max_token_len = (
self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
)
micro_batches, indices = rearrange_micro_batches(
batch=rm_data.batch, max_token_len=max_token_len
)
else:
micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)
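# In the dynamic-bsz branch, rearrange_micro_batches re-packs samples so each
# micro-batch stays under max_token_len tokens; `indices` records the
# permutation, which is inverted below via get_reverse_idx so the scores are
# restored to the original sample order.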
output = []
for micro_batch in micro_batches:
rm_score = self._forward_micro_batch(micro_batch)
output.append(rm_score)
scores = torch.cat(output, dim=0) # (batch_size)
if use_dynamic_bsz:
indices = list(itertools.chain.from_iterable(indices))
assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}"
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
scores = scores[revert_indices]
token_level_scores = self._expand_to_token_level(data, scores)
# Note that these are only the raw scores; they may not be the final rewards used to train the RL policy
output = DataProto.from_dict(tensors={"rm_scores": token_level_scores})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
self.reward_module._handle.reshard(True)
output = output.to("cpu")
return output