Source code for trinity.trainer.verl.fsdp_workers

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm
"""

import logging
import os
import warnings

import psutil
import torch
import torch.distributed
import verl.utils.torch_functional as verl_F
from codetiming import Timer
from omegaconf import DictConfig, open_dict
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FlatParameter
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import FSDP_PREFIX
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
    get_fsdp_wrap_policy,
    get_init_weight_context_manager,
    init_fn,
    load_fsdp_model_to_gpu,
    load_fsdp_optimizer,
    offload_fsdp_model_to_cpu,
    offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

from trinity.common.constants import AlgorithmType, SyncMethod
from trinity.utils.distributed import init_process_group, is_ipv6_address

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))


def create_device_mesh(world_size, fsdp_size):
    if fsdp_size < 0 or fsdp_size >= world_size:
        device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
    else:
        device_mesh = init_device_mesh(
            "cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]
        )
    return device_mesh

def get_sharding_strategy(device_mesh):
    from torch.distributed.fsdp import ShardingStrategy

    if device_mesh.ndim == 1:
        sharding_strategy = ShardingStrategy.FULL_SHARD
    elif device_mesh.ndim == 2:
        sharding_strategy = ShardingStrategy.HYBRID_SHARD
    else:
        raise NotImplementedError(
            f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2"
        )
    return sharding_strategy

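# Worked example (added for clarity, not part of the upstream source): with
# world_size=8 and fsdp_size=4, create_device_mesh returns a (2, 4) mesh with
# dims ["ddp", "fsdp"], and get_sharding_strategy maps its ndim=2 to
# ShardingStrategy.HYBRID_SHARD; with fsdp_size=-1 (or fsdp_size >= world_size)
# the mesh is one-dimensional and FULL_SHARD is used instead.
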
class ActorRolloutRefWorker(Worker):
    """
    This worker can be instantiated as a standalone actor, a standalone rollout,
    a standalone reference policy, or a hybrid engine based on config.rollout.
    """

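    # The ``role`` argument selects which of these parts are built; the values
    # accepted by ``__init__`` below are "actor", "rollout", "ref",
    # "actor_rollout", and "actor_rollout_ref".
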
[docs] def __init__(self, config: DictConfig, role: str): super().__init__() self.config = config import torch.distributed if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl") # build device mesh for FSDP world_size = torch.distributed.get_world_size() # TODO(sgm): support FSDP hybrid shard for larger model self.device_mesh = create_device_mesh( world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size ) # build device mesh for Ulysses Sequence Parallel self.ulysses_device_mesh = None self.ulysses_sequence_parallel_size = self.config.actor.get( "ulysses_sequence_parallel_size", 1 ) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( "cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"], ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) self.role = role assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] self._is_ref = self.role in ["ref", "actor_rollout_ref"] self._is_offload_param = False self._is_offload_optimizer = False if self._is_actor: self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False) self._is_offload_optimizer = self.config.actor.fsdp_config.get( "optimizer_offload", False ) elif self._is_ref: # TODO: it seems that manual offload is slowly than FSDP offload self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False) # normalize config if self._is_actor: self.config.actor.ppo_mini_batch_size *= self.config.rollout.n self.config.actor.ppo_mini_batch_size //= ( self.device_mesh.size() // self.ulysses_sequence_parallel_size ) assert ( self.config.actor.ppo_mini_batch_size > 0 ), f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization" # micro bsz if self.config.actor.ppo_micro_batch_size is not None: self.config.actor.ppo_micro_batch_size //= ( self.device_mesh.size() // self.ulysses_sequence_parallel_size ) self.config.actor.ppo_micro_batch_size_per_gpu = ( self.config.actor.ppo_micro_batch_size ) assert ( self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0 ), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" assert ( self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0 ), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" # normalize rollout config if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None: self.config.rollout.log_prob_micro_batch_size //= ( self.device_mesh.size() // self.ulysses_sequence_parallel_size ) self.config.rollout.log_prob_micro_batch_size_per_gpu = ( self.config.rollout.log_prob_micro_batch_size ) # normalize ref config if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None: self.config.ref.log_prob_micro_batch_size //= ( self.device_mesh.size() // self.ulysses_sequence_parallel_size ) self.config.ref.log_prob_micro_batch_size_per_gpu = ( self.config.ref.log_prob_micro_batch_size )
def _build_model_optimizer( self, model_path, fsdp_config, optim_config, override_model_config, use_remove_padding=False, enable_gradient_checkpointing=False, trust_remote_code=False, use_liger=False, role="actor", ): from torch import optim from torch.distributed.fsdp import CPUOffload from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision from transformers import ( AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, ) from verl.utils.model import ( get_generation_config, print_model_size, update_model_config, ) from verl.utils.torch_dtypes import PrecisionType assert role in ["actor", "ref"] log_gpu_memory_usage("Before init from HF AutoModel", logger=logger) local_path = copy_to_local(model_path) # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code) torch_dtype = fsdp_config.get("model_dtype", None) if torch_dtype is None: torch_dtype = torch.float32 if self._is_actor else torch.bfloat16 else: torch_dtype = PrecisionType.to_dtype(torch_dtype) # override model kwargs actor_model_config = AutoConfig.from_pretrained( local_path, trust_remote_code=trust_remote_code ) self.generation_config = get_generation_config( local_path, trust_remote_code=trust_remote_code ) override_config_kwargs = { "bos_token_id": self.tokenizer.bos_token_id, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, } override_config_kwargs.update(override_model_config) update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs) if self.rank == 0: print(f"Model config after override: {actor_model_config}") # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang init_context = get_init_weight_context_manager( use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys(): actor_module_class = AutoModelForVision2Seq else: actor_module_class = AutoModelForCausalLM actor_module = actor_module_class.from_pretrained( pretrained_model_name_or_path=local_path, torch_dtype=torch_dtype, config=actor_model_config, attn_implementation="flash_attention_2", trust_remote_code=trust_remote_code, ) if use_remove_padding or self.ulysses_sequence_parallel_size > 1: from verl.models.transformers.monkey_patch import apply_monkey_patch apply_monkey_patch( model=actor_module, ulysses_sp_size=self.ulysses_sequence_parallel_size ) # Apply Liger kernel to the model if use_liger is set to True if use_liger: from liger_kernel.transformers.monkey_patch import ( _apply_liger_kernel_to_instance, ) _apply_liger_kernel_to_instance(model=actor_module) # some parameters may not in torch_dtype. 
TODO(zhangchi.usc1992) remove this after we switch to fsdp2 actor_module.to(torch_dtype) if enable_gradient_checkpointing: actor_module.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) torch.distributed.barrier() if self.rank == 0: print_model_size(actor_module) log_gpu_memory_usage("After init from HF AutoModel", logger=logger) # We wrap FSDP for rollout as well mixed_precision_config = fsdp_config.get("mixed_precision", None) if mixed_precision_config is not None: param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) reduce_dtype = PrecisionType.to_dtype( mixed_precision_config.get("reduce_dtype", "fp32") ) buffer_dtype = PrecisionType.to_dtype( mixed_precision_config.get("buffer_dtype", "fp32") ) else: param_dtype = torch.bfloat16 reduce_dtype = torch.float32 buffer_dtype = torch.float32 mixed_precision = MixedPrecision( param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype ) auto_wrap_policy = get_fsdp_wrap_policy( module=actor_module, config=fsdp_config.get("wrap_policy", None) ) if self._is_rollout and self.config.rollout.name == "hf": # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma auto_wrap_policy = None print(f"wrap_policy: {auto_wrap_policy}") fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) # TODO: add transformer policy # We force reference policy to use CPUOffload to save memory. # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation cpu_offload = None if role == "actor" else CPUOffload(offload_params=True) actor_module_fsdp = FSDP( actor_module, cpu_offload=cpu_offload, param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, device_id=torch.cuda.current_device(), sharding_strategy=sharding_strategy, # zero3 mixed_precision=mixed_precision, sync_module_states=True, device_mesh=self.device_mesh, forward_prefetch=False, ) log_gpu_memory_usage("After Actor FSDP init", logger=logger) # TODO: add more optimizer args into config if role == "actor" and optim_config is not None: beta1 = optim_config.get("beta1", 0.9) beta2 = optim_config.get("beta2", 0.999) actor_optimizer = optim.AdamW( actor_module_fsdp.parameters(), lr=optim_config.lr, betas=(beta1, beta2), weight_decay=optim_config.get("weight_decay", 1e-2), ) total_steps = optim_config.get("total_training_steps", 0) num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1)) if num_warmup_steps < 0: num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") if optim_config.warmup_style == "constant": from verl.utils.torch_functional import ( get_constant_schedule_with_warmup, ) actor_lr_scheduler = get_constant_schedule_with_warmup( optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps ) elif optim_config.warmup_style == "cosine": from verl.utils.torch_functional import get_cosine_schedule_with_warmup assert ( total_steps > 0 ), "Cosine scheduler of actor requires total_training_steps > 0" actor_lr_scheduler = get_cosine_schedule_with_warmup( optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps, min_lr_ratio=optim_config.min_lr_ratio, ) else: raise NotImplementedError( f"Lr scheduler style {optim_config.warmup_style} is not supported" ) else: actor_optimizer = None 
actor_lr_scheduler = None log_gpu_memory_usage("After actor optimizer init", logger=logger) return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config def _build_rollout(self): from torch.distributed.device_mesh import init_device_mesh # TODO(sgm): support FSDP hybrid shard for larger model infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp assert ( self.world_size % infer_tp == 0 ), f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" rollout_device_mesh = init_device_mesh( "cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"] ) if self.config.rollout.name == "hf": from verl.workers.rollout import HFRollout from verl.workers.sharding_manager import BaseShardingManager rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout) rollout_sharding_manager = BaseShardingManager() # TODO: a sharding manager that do nothing? elif self.config.rollout.name == "vllm": if self.config.rollout.use_fire_sampling: from verl.workers.rollout.vllm_rollout import ( FIREvLLMRollout as vLLMRollout, ) from verl.workers.rollout.vllm_rollout import vllm_mode else: from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode from verl.workers.sharding_manager import FSDPVLLMShardingManager log_gpu_memory_usage("Before building vllm rollout", logger=None) local_path = copy_to_local(self.config.model.path) if vllm_mode == "customized": rollout = vLLMRollout( actor_module=self.actor_module_fsdp, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config, ) elif vllm_mode == "spmd": rollout = vLLMRollout( model_path=local_path, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config, device_mesh=rollout_device_mesh, ) else: raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'") log_gpu_memory_usage("After building vllm rollout", logger=None) if torch.distributed.get_world_size() == 1: self.config.rollout.load_format = "dummy_hf" rollout_sharding_manager = FSDPVLLMShardingManager( module=self.actor_module_fsdp, inference_engine=rollout.inference_engine, model_config=self.actor_model_config, full_params="hf" in self.config.rollout.load_format, device_mesh=rollout_device_mesh, ) log_gpu_memory_usage("After building sharding manager", logger=None) return rollout, rollout_sharding_manager
[docs] @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): from .dp_actor import DataParallelPPOActor # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) from omegaconf import OmegaConf override_model_config = OmegaConf.to_container( self.config.model.get("override_config", OmegaConf.create()) ) use_remove_padding = self.config.model.get("use_remove_padding", False) if self._is_actor or self._is_rollout: # we need the model for actor and rollout if self._is_actor: optim_config = self.config.actor.optim fsdp_config = self.config.actor.fsdp_config else: optim_config = None fsdp_config = OmegaConf.create() ( self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config, ) = self._build_model_optimizer( model_path=self.config.model.path, fsdp_config=fsdp_config, optim_config=optim_config, override_model_config=override_model_config, use_remove_padding=use_remove_padding, enable_gradient_checkpointing=self.config.model.get( "enable_gradient_checkpointing", False ), trust_remote_code=self.config.model.get("trust_remote_code", False), use_liger=self.config.model.get("use_liger", False), role="actor", ) # get the original unwrapped module self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) # load from checkpoint if self._is_actor: OmegaConf.set_struct(self.config.actor, True) with open_dict(self.config.actor): self.config.actor.use_remove_padding = use_remove_padding self.actor = DataParallelPPOActor( config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer, ) if self._is_rollout: self.rollout, self.rollout_sharding_manager = self._build_rollout() if self._is_ref: self.ref_module_fsdp = self._build_model_optimizer( model_path=self.config.model.path, fsdp_config=self.config.ref.fsdp_config, optim_config=None, override_model_config=override_model_config, use_remove_padding=use_remove_padding, trust_remote_code=self.config.model.get("trust_remote_code", False), use_liger=self.config.model.get("use_liger", False), role="ref", )[0] OmegaConf.set_struct(self.config.ref, True) with open_dict(self.config.ref): self.config.ref.use_remove_padding = use_remove_padding self.ref_policy = DataParallelPPOActor( config=self.config.ref, actor_module=self.ref_module_fsdp ) self.checkpoint_manager = FSDPCheckpointManager( model=self.ref_module_fsdp, optimizer=None, lr_scheduler=None, processing_class=self.processor if self.processor is not None else self.tokenizer, ) if self._is_actor: self.flops_counter = FlopsCounter(self.actor_model_config) self.checkpoint_manager = FSDPCheckpointManager( model=self.actor_module_fsdp, optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, checkpoint_contents=self.config.actor.checkpoint.contents, ) torch.cuda.empty_cache()
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def setup_weight_sync_group(self):
        if (
            hasattr(self.config, "synchronizer")
            and getattr(self.config.synchronizer, "sync_method", None) == SyncMethod.NCCL
        ):
            model = self.actor_module_fsdp
            self.named_modules = []
            self.state_dict_meta = []
            for name, module in model.named_modules():
                if isinstance(module, FSDP):
                    self.named_modules.append((name, module))
            for name_prefix, module in self.named_modules:
                with FSDP.summon_full_params(module, recurse=False):
                    for name, param in module.named_parameters():
                        if isinstance(param, FlatParameter):
                            continue
                        realname = (
                            name_prefix[len(FSDP_PREFIX) :] + "." + name if name_prefix else name
                        )
                        self.state_dict_meta.append((realname, param.dtype, param.shape))
                    param = None
            torch.cuda.empty_cache()
            if torch.distributed.get_rank() == 0:
                import ray

                master_address, master_port = self.get_availale_master_addr_port()
                world_size = self.config.synchronizer.explorer_world_size + 1
                print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).")
                explorer = ray.get_actor("explorer")
                group_name = "rollout_weight_sync"
                setup_ref = explorer.setup_weight_sync_group.remote(
                    master_address, master_port, self.state_dict_meta
                )
                if is_ipv6_address(master_address):
                    # using tcp://ipv6:port will lead to ValueError
                    init_method = f"tcp://[{master_address}]:{master_port}"
                else:
                    init_method = f"tcp://{master_address}:{master_port}"
                timeout = self.config.synchronizer.sync_timeout
                self._model_update_group = init_process_group(
                    backend="nccl",
                    init_method=init_method,
                    timeout=timeout,
                    world_size=world_size,
                    rank=0,
                    group_name=group_name,
                )
                ray.get(setup_ref)
            torch.distributed.barrier()

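    # NCCL weight-sync flow (summary of the method above): every FSDP-wrapped
    # submodule is recorded together with (name, dtype, shape) metadata for its
    # non-flat parameters, rank 0 joins a "rollout_weight_sync" process group
    # shared with the explorer actor, and sync_weight below broadcasts the full
    # parameters from rank 0 over that group.
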
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def sync_weight(self):
        for name_prefix, module in self.named_modules:
            with FSDP.summon_full_params(module, recurse=False):
                if torch.distributed.get_rank() == 0:
                    for name, param in module.named_parameters():
                        if isinstance(param, FlatParameter):
                            continue
                        torch.distributed.broadcast(param, 0, group=self._model_update_group)
                param = None
        torch.cuda.empty_cache()

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def set_mode(self, algo_type: AlgorithmType = AlgorithmType.PPO):
        self.actor.set_mode(algo_type)

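    # set_mode above simply forwards the algorithm type to
    # DataParallelPPOActor.set_mode (imported from .dp_actor in init_model); the
    # actor is expected to switch its update behaviour accordingly, defaulting to
    # AlgorithmType.PPO.
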
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_actor(self, data: DataProto):
        # Support all hardwares
        data = data.to(torch.cuda.current_device())

        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)
        if self._is_offload_optimizer:
            load_fsdp_optimizer(
                optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()
            )

        log_gpu_memory_usage("Before update policy", logger=logger)

        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)
            # perform training
            with Timer(name="update_policy", logger=None) as timer:
                metrics = self.actor.update_policy(data=data)
            delta_time = timer.last
            global_num_tokens = data.meta_info["global_token_num"]
            estimated_flops, promised_flops = self.flops_counter.estimate_flops(
                global_num_tokens, delta_time
            )
            metrics["perf/mfu/actor"] = (
                estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
            )
            metrics["perf/max_memory_allocated_gb"] = torch.cuda.max_memory_allocated() / (
                1024**3
            )
            metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3)
            metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)

            self.actor_lr_scheduler.step()
            lr = self.actor_lr_scheduler.get_last_lr()[0]
            metrics["actor/lr"] = lr

            log_gpu_memory_usage("After update policy", logger=logger)

            # TODO: here, we should return all metrics
            output = DataProto(meta_info={"metrics": metrics})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)
            output = output.to("cpu")

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def generate_sequences(self, prompts: DataProto):
        # Support all hardwares
        prompts = prompts.to(torch.cuda.current_device())

        assert self._is_rollout
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)

        meta_info = {
            "eos_token_id": self.generation_config.eos_token_id
            if self.generation_config is not None
            else self.tokenizer.eos_token_id,
            "pad_token_id": self.generation_config.pad_token_id
            if self.generation_config is not None
            else self.tokenizer.pad_token_id,
        }
        prompts.meta_info.update(meta_info)
        with self.rollout_sharding_manager:
            # after parameters sync with rollout, offload actor model to CPU
            if self._is_offload_param:
                offload_fsdp_model_to_cpu(self.actor_module_fsdp)
            if self._is_offload_optimizer:
                offload_fsdp_optimizer(optimizer=self.actor_optimizer)

            log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)

            prompts = self.rollout_sharding_manager.preprocess_data(prompts)
            output = self.rollout.generate_sequences(prompts=prompts)

            log_gpu_memory_usage("After rollout generation", logger=logger)

            output = self.rollout_sharding_manager.postprocess_data(output)

        output = output.to("cpu")

        # clear kv cache
        torch.cuda.empty_cache()
        log_gpu_memory_usage("After recompute log prob", logger=logger)
        return output

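    # For the vLLM rollout, entering self.rollout_sharding_manager is expected to
    # push the current FSDP weights into the inference engine
    # (FSDPVLLMShardingManager, see _build_rollout), which is why the actor
    # parameters and optimizer state can be offloaded to CPU for the duration of
    # generation.
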
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_log_prob(self, data: DataProto):
        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)
        # Support all hardwares
        data = data.to(torch.cuda.current_device())
        # we should always recompute old_log_probs when it is HybridEngine
        data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu
        data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz
        data.meta_info["temperature"] = self.config.rollout.temperature
        # perform recompute log_prob
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            output = self.actor.compute_log_prob(data=data)
            output = DataProto.from_dict(
                tensors={"old_log_probs": output},
                meta_info={"temperature": self.config.rollout.temperature},
            )
            output = self.ulysses_sharding_manager.postprocess_data(output)

        output = output.to("cpu")

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # unshard the root FSDP module
        if self.world_size > 1:
            self.actor.actor_module._handle.reshard(True)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)

        log_gpu_memory_usage("After compute_log_prob", logger=logger)
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_ref_log_prob(self, data: DataProto):
        assert self._is_ref

        # Support all hardwares
        data = data.to(torch.cuda.current_device())

        micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
        data.meta_info["micro_batch_size"] = micro_batch_size
        data.meta_info["temperature"] = self.config.rollout.temperature
        data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            output = self.ref_policy.compute_log_prob(data=data)
            output = DataProto.from_dict(tensors={"ref_log_prob": output})
            output = self.ulysses_sharding_manager.postprocess_data(output)

        output = output.to("cpu")

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # unshard the root FSDP module
        if self.world_size > 1:
            self.ref_policy.actor_module._handle.reshard(True)

        torch.cuda.empty_cache()
        return output

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        # only support save and load ckpt for actor
        assert self._is_actor
        import torch

        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)

        self.checkpoint_manager.save_checkpoint(
            local_path=local_path,
            hdfs_path=hdfs_path,
            global_step=global_step,
            max_ckpt_to_keep=max_ckpt_to_keep,
        )

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
        if self._is_actor and self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)

        self.checkpoint_manager.load_checkpoint(
            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
        )

        if self._is_actor and self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)

        if self._is_actor and self._is_offload_optimizer:
            offload_fsdp_optimizer(self.actor_optimizer)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def clear_optimizer_state(self):
        print("Clear actor optimizer state")
        if self._is_offload_optimizer:
            load_fsdp_optimizer(
                optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()
            )
        self.actor_optimizer.state.clear()
        self.actor_optimizer.zero_grad()
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(self.actor_optimizer)

class CriticWorker(Worker):
[docs] def __init__(self, config): super().__init__() import torch.distributed if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl") self.config = config # build device mesh for Ulysses Sequence Parallel world_size = torch.distributed.get_world_size() from torch.distributed.device_mesh import init_device_mesh fsdp_size = self.config.model.fsdp_config.fsdp_size self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) self.ulysses_device_mesh = None self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( "cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"], ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) # set FSDP offload params self._is_offload_param = self.config.model.fsdp_config.param_offload self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload # normalize config self.config.ppo_mini_batch_size *= self.config.rollout_n self.config.ppo_mini_batch_size //= ( torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size ) if self.config.ppo_micro_batch_size is not None: self.config.ppo_micro_batch_size //= ( torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size ) self.config.forward_micro_batch_size //= ( torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size ) self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size assert ( self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0 ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" assert ( self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0 ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
def _build_critic_model_optimizer(self, config): # the following line is necessary from torch import optim from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision from verl.utils.model import print_model_size from verl.utils.torch_dtypes import PrecisionType local_path = copy_to_local(config.model.path) # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info # using random initialized model from any architecture. May not be the same as Actor. tokenizer_path = copy_to_local(config.model.tokenizer_path) self.tokenizer = hf_tokenizer( tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False) ) self.processor = hf_processor( tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False) ) from omegaconf import OmegaConf override_config = OmegaConf.to_container( self.config.model.get("override_config", OmegaConf.create()) ) override_config_kwargs = { "bos_token_id": self.tokenizer.bos_token_id, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, } override_config_kwargs.update(override_config) if self.rank == 0: print(f"Critic overriding config {override_config_kwargs}") torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") torch_dtype = PrecisionType.to_dtype(torch_dtype) from transformers import AutoConfig, AutoModelForTokenClassification trust_remote_code = False critic_model_config = AutoConfig.from_pretrained( local_path, trust_remote_code=trust_remote_code ) critic_model_config.num_labels = 1 init_context = get_init_weight_context_manager( use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") setattr(critic_model_config, "classifier_dropout", 0.0) setattr(critic_model_config, "hidden_dropout", "0") critic_module = AutoModelForTokenClassification.from_pretrained( pretrained_model_name_or_path=local_path, torch_dtype=torch_dtype, config=critic_model_config, attn_implementation="flash_attention_2", trust_remote_code=trust_remote_code, ) use_remove_padding = config.model.get("use_remove_padding", False) if use_remove_padding or self.ulysses_sequence_parallel_size > 1: from verl.models.transformers.monkey_patch import apply_monkey_patch apply_monkey_patch( model=critic_module, ulysses_sp_size=self.ulysses_sequence_parallel_size ) # some parameters may not in torch_dtype critic_module.to(torch_dtype) if config.model.get("enable_gradient_checkpointing", False): critic_module.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) if self.rank == 0: print_model_size(critic_module) self.critic_model_config = critic_model_config fsdp_config = self.config.model.fsdp_config mixed_precision_config = fsdp_config.get("mixed_precision", None) if mixed_precision_config is not None: param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) reduce_dtype = PrecisionType.to_dtype( mixed_precision_config.get("reduce_dtype", "fp32") ) buffer_dtype = PrecisionType.to_dtype( mixed_precision_config.get("buffer_dtype", "fp32") ) else: param_dtype = torch.bfloat16 reduce_dtype = torch.float32 buffer_dtype = torch.float32 mixed_precision = MixedPrecision( param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype ) auto_wrap_policy = get_fsdp_wrap_policy( module=critic_module, 
config=self.config.model.fsdp_config.wrap_policy ) log_gpu_memory_usage("Before critic FSDP", logger=None) fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation critic_module = FSDP( critic_module, param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, device_id=torch.cuda.current_device(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, sync_module_states=True, forward_prefetch=False, device_mesh=self.device_mesh, cpu_offload=None, ) log_gpu_memory_usage("After critic FSDP", logger=None) beta1 = config.optim.get("beta1", 0.9) beta2 = config.optim.get("beta2", 0.999) critic_optimizer = optim.AdamW( critic_module.parameters(), lr=config.optim.lr, betas=(beta1, beta2), weight_decay=config.optim.get("weight_decay", 1e-2), ) total_steps = config.optim.get("total_training_steps", 0) num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1)) if num_warmup_steps < 0: num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") if config.optim.warmup_style == "constant": from verl.utils.torch_functional import get_constant_schedule_with_warmup critic_lr_scheduler = get_constant_schedule_with_warmup( optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps ) elif config.optim.warmup_style == "cosine": from verl.utils.torch_functional import get_cosine_schedule_with_warmup assert total_steps > 0, "Cosine scheduler of critic requires total_training_steps > 0" critic_lr_scheduler = get_cosine_schedule_with_warmup( optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps, min_lr_ratio=config.optim.min_lr_ratio, ) else: raise NotImplementedError( f"Lr scheduler style {config.optim.warmup_style} is not supported" ) return critic_module, critic_optimizer, critic_lr_scheduler
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get("external_lib", None))

        from verl.workers.critic import DataParallelPPOCritic

        (
            self.critic_module,
            self.critic_optimizer,
            self.critic_lr_scheduler,
        ) = self._build_critic_model_optimizer(self.config)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.critic_optimizer)

        self.critic = DataParallelPPOCritic(
            config=self.config,
            critic_module=self.critic_module,
            critic_optimizer=self.critic_optimizer,
        )

        self.flops_counter = FlopsCounter(self.critic_model_config)
        self.checkpoint_manager = FSDPCheckpointManager(
            model=self.critic_module,
            optimizer=self.critic_optimizer,
            lr_scheduler=self.critic_lr_scheduler,
            processing_class=self.processor if self.processor is not None else self.tokenizer,
            checkpoint_contents=self.config.checkpoint.contents,
        )

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_values(self, data: DataProto):
        # Support all hardwares
        data = data.to(torch.cuda.current_device())

        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)
        micro_batch_size = self.config.forward_micro_batch_size_per_gpu
        data.meta_info["micro_batch_size"] = micro_batch_size
        data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz
        # perform forward computation
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)
            values = self.critic.compute_values(data=data)
            output = DataProto.from_dict(tensors={"values": values})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)

        output = output.to("cpu")
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_critic(self, data: DataProto):
        # Support all hardwares
        data = data.to(torch.cuda.current_device())
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)
        if self._is_offload_optimizer:
            load_fsdp_optimizer(
                optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()
            )

        # perform forward computation
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)

            with Timer(name="update_critic", logger=None) as timer:
                metrics = self.critic.update_critic(data=data)
            delta_time = timer.last

            global_num_tokens = data.meta_info["global_token_num"]
            estimated_flops, promised_flops = self.flops_counter.estimate_flops(
                global_num_tokens, delta_time
            )
            metrics["perf/mfu/critic"] = (
                estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
            )

            self.critic_lr_scheduler.step()
            lr = self.critic_lr_scheduler.get_last_lr()[0]
            metrics["critic/lr"] = lr

            output = DataProto(batch=None, meta_info={"metrics": metrics})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.critic_optimizer)
        output = output.to("cpu")
        return output

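    # Reading the MFU metric computed above:
    #   perf/mfu/critic = estimated_flops * ppo_epochs / promised_flops / world_size
    # FlopsCounter estimates the FLOPs spent on global_token_num tokens during
    # delta_time, so the metric is achieved compute as a fraction of the hardware's
    # promised peak, averaged over all workers.
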
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        import torch

        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)

        self.checkpoint_manager.save_checkpoint(
            local_path=local_path,
            hdfs_path=hdfs_path,
            global_step=global_step,
            max_ckpt_to_keep=max_ckpt_to_keep,
        )

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
        import torch

        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)

        self.checkpoint_manager.load_checkpoint(
            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
        )

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)

        if self._is_offload_optimizer:
            offload_fsdp_optimizer(self.critic_optimizer)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def clear_optimizer_state(self):
        print("Clear critic optimizer state")
        if self._is_offload_optimizer:
            load_fsdp_optimizer(
                optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()
            )
        self.critic_optimizer.state.clear()
        self.critic_optimizer.zero_grad()
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(self.critic_optimizer)

# TODO(sgm): we may need to extract it to dp_reward_model.py
class RewardModelWorker(Worker):
    """
    Note that we only implement reward models that are subclasses of
    AutoModelForTokenClassification.
    """

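    # The model is loaded with num_labels = 1 (see _build_model), so the
    # token-classification head emits one scalar per token; _forward_micro_batch
    # then keeps only the logit at the last valid (non-padded) position as the
    # sequence-level reward score.
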
[docs] def __init__(self, config): super().__init__() import torch.distributed if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl") self.config = config # build device mesh for Ulysses Sequence Parallel world_size = torch.distributed.get_world_size() from torch.distributed.device_mesh import init_device_mesh fsdp_size = self.config.model.fsdp_config.fsdp_size self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) self.ulysses_device_mesh = None self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( "cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"], ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) self.use_remove_padding = self.config.model.get("use_remove_padding", False) # normalize config if self.config.micro_batch_size is not None: self.config.micro_batch_size //= torch.distributed.get_world_size() self.config.micro_batch_size_per_gpu = self.config.micro_batch_size
def _build_model(self, config): # the following line is necessary from torch.distributed.fsdp import CPUOffload from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from transformers import AutoConfig, AutoModelForTokenClassification # download the checkpoint from hdfs local_path = copy_to_local(config.model.path) if self.config.model.input_tokenizer is None: self._do_switch_chat_template = False else: self._do_switch_chat_template = True input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer) self.input_tokenizer = hf_tokenizer( input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False), ) self.tokenizer = hf_tokenizer( local_path, trust_remote_code=config.model.get("trust_remote_code", False) ) trust_remote_code = config.model.get("trust_remote_code", False) model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) model_config.num_labels = 1 # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect init_context = get_init_weight_context_manager( use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") setattr(model_config, "classifier_dropout", 0.0) reward_module = AutoModelForTokenClassification.from_pretrained( pretrained_model_name_or_path=local_path, config=model_config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", trust_remote_code=trust_remote_code, ) if ( config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1 ): from verl.models.transformers.monkey_patch import apply_monkey_patch apply_monkey_patch( model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size ) reward_module.to(torch.bfloat16) auto_wrap_policy = get_fsdp_wrap_policy( module=reward_module, config=self.config.model.fsdp_config ) fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) reward_module = FSDP( reward_module, param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, device_id=torch.cuda.current_device(), sharding_strategy=sharding_strategy, # zero3 sync_module_states=True, cpu_offload=CPUOffload(offload_params=True), forward_prefetch=False, device_mesh=self.device_mesh, ) return reward_module
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get("external_lib", None))
        self.reward_module = self._build_model(config=self.config)

def _forward_micro_batch(self, micro_batch): from flash_attn.bert_padding import ( index_first_axis, pad_input, rearrange, unpad_input, ) from verl.utils.ulysses import ( gather_outpus_and_unpad, ulysses_pad_and_slice_inputs, ) with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] position_ids = micro_batch["position_ids"] if self.use_remove_padding: input_ids_rmpad, indices, *_ = unpad_input( input_ids.unsqueeze(-1), attention_mask ) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) # unpad the position_ids to align the rotary position_ids_rmpad = index_first_axis( rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices ).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size, ) # only pass input_ids and position_ids to enable flash_attn_varlen output = self.reward_module( input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False, ) # prevent model thinks we are generating reward_rmpad = output.logits reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) # gather output if sp > 1 if self.ulysses_sequence_parallel_size > 1: reward_rmpad = gather_outpus_and_unpad( reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size ) # pad it back rm_score = pad_input( reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen ).squeeze(-1) else: output = self.reward_module( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False, ) rm_score = output.logits # (batch_size, seq_len, 1) rm_score = rm_score.squeeze(-1) # extract the result of the last valid token eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) rm_score = rm_score[torch.arange(batch_size), eos_mask_idx] return rm_score def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): batch_size = data.batch.batch_size[0] # expand as token_level_reward attention_mask = data.batch["attention_mask"] position_ids = data.batch["position_ids"] response_length = data.batch["responses"].shape[-1] eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores # select the response part token_level_scores = token_level_scores[:, -response_length:] return token_level_scores def _switch_chat_template(self, data: DataProto): src_max_length = data.batch["attention_mask"].shape[-1] src_tokenizer = self.input_tokenizer target_tokenizer = self.tokenizer rm_input_ids = [] rm_attention_mask = [] for i in range(data.batch.batch_size[0]): # extract raw prompt chat: list = data.non_tensor_batch["raw_prompt"][i].tolist() # extract response response_ids = data.batch["responses"][i] response_length = response_ids.shape[-1] valid_response_length = data.batch["attention_mask"][i][-response_length:].sum() valid_response_ids = response_ids[:valid_response_length] # decode response = src_tokenizer.decode(valid_response_ids) # remove bos and eos response = response.replace(src_tokenizer.eos_token, "") chat.append({"role": "assistant", "content": response}) 
prompt_with_chat_template = target_tokenizer.apply_chat_template( chat, add_generation_prompt=False, tokenize=False ) if self.rank == 0 and i == 0: # for debugging purpose print(f"Switch template. chat: {prompt_with_chat_template}") # the maximum length is actually determined by the reward model itself max_length = self.config.get("max_length", src_max_length) if max_length is None: max_length = src_max_length input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( prompt=prompt_with_chat_template, tokenizer=target_tokenizer, max_length=max_length, pad_token_id=target_tokenizer.pad_token_id, left_pad=False, # right padding truncation=self.config.get("truncation", "right"), ) # truncate from the right rm_input_ids.append(input_ids) rm_attention_mask.append(attention_mask) rm_input_ids = torch.cat(rm_input_ids, dim=0) rm_attention_mask = torch.cat(rm_attention_mask, dim=0) rm_position_ids = compute_position_id_with_mask(rm_attention_mask) rm_inputs = { "input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids, } return DataProto.from_dict(rm_inputs)
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_rm_score(self, data: DataProto):
        import itertools

        from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches

        # Support all hardwares
        data = data.to(torch.cuda.current_device())
        if self._do_switch_chat_template:
            rm_data = self._switch_chat_template(data)

        # Support all hardwares
        rm_data.batch = rm_data.batch.to(torch.cuda.current_device())

        # perform forward computation
        with self.ulysses_sharding_manager:
            rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data)
            data = self.ulysses_sharding_manager.preprocess_data(data=data)

            use_dynamic_bsz = self.config.use_dynamic_bsz
            if use_dynamic_bsz:
                max_token_len = (
                    self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                )
                micro_batches, indices = rearrange_micro_batches(
                    batch=rm_data.batch, max_token_len=max_token_len
                )
            else:
                micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)
            output = []
            for micro_batch in micro_batches:
                rm_score = self._forward_micro_batch(micro_batch)
                output.append(rm_score)
            scores = torch.cat(output, dim=0)  # (batch_size)

            if use_dynamic_bsz:
                indices = list(itertools.chain.from_iterable(indices))
                assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}"
                revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
                scores = scores[revert_indices]

            token_level_scores = self._expand_to_token_level(data, scores)
            # Note that this is only the scores, may not be the final rewards used to train RL
            output = DataProto.from_dict(tensors={"rm_scores": token_level_scores})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # unshard the root FSDP module
        self.reward_module._handle.reshard(True)

        output = output.to("cpu")
        return output

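    # Note on dynamic batching in compute_rm_score above: rearrange_micro_batches
    # packs sequences by token count rather than by a fixed micro-batch size, so the
    # per-sequence scores come back permuted; get_reverse_idx builds the inverse
    # permutation that restores the original sample order before the scores are
    # expanded to token level.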