Source code for trinity.trainer.verl.fsdp_workers

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm.
Modified from https://github.com/volcengine/verl/blob/v0.5.0/verl/workers/fsdp_workers.py
"""

import json
import logging
import os
import warnings
from contextlib import contextmanager
from dataclasses import asdict
from datetime import timedelta

import psutil
import torch
import torch.distributed
import torch.distributed as dist
import vllm  # noqa: F401 ; import vllm to set NCCL_CUMEM_ENABLE automatically.
from codetiming import Timer
from omegaconf import DictConfig, OmegaConf, open_dict
from peft import LoraConfig, TaskType, get_peft_model
from safetensors.torch import save_file
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FlatParameter
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import FSDP_PREFIX
from verl import DataProto
from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.activation_offload import enable_activation_offloading
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.device import (
    get_device_id,
    get_device_name,
    get_nccl_backend,
    get_torch_device,
)
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
    apply_fsdp2,
    fsdp2_load_full_state_dict,
    fsdp_version,
    get_fsdp_wrap_policy,
    get_init_weight_context_manager,
    init_fn,
    layered_summon_lora_params,
    load_fsdp_model_to_gpu,
    load_fsdp_optimizer,
    offload_fsdp_model_to_cpu,
    offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.py_functional import convert_to_regular_types
from verl.workers.fsdp_workers import (
    create_device_mesh,
    device_name,
    get_sharding_strategy,
)
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

from trinity.common.config import AlgorithmConfig
from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
from trinity.manager.synchronizer import Synchronizer
from trinity.trainer.verl.fsdp_checkpoint_manager import FSDPCheckpointManager
from trinity.utils.distributed import init_process_group

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class ActorRolloutRefWorker(Worker):
    """
    This worker can be instantiated as a standalone actor, a standalone rollout,
    a standalone reference policy, or a hybrid engine, depending on config.rollout.
    """

    def __init__(self, config: DictConfig, role: str):
        super().__init__()
        self.config = config
        import torch.distributed

        if not torch.distributed.is_initialized():
            rank = int(os.environ.get("RANK", 0))
            world_size = int(os.environ.get("WORLD_SIZE", 1))
            torch.distributed.init_process_group(
                backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}",
                rank=rank,
                world_size=world_size,
                init_method=os.environ.get("DIST_INIT_METHOD", None),
                timeout=timedelta(seconds=self.config.synchronizer.sync_timeout),
            )

        # build device mesh for FSDP
        world_size = torch.distributed.get_world_size()
        # TODO(sgm): support FSDP hybrid shard for larger model
        self.device_mesh = create_device_mesh(
            world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size
        )

        # build device mesh for Ulysses Sequence Parallel
        self.ulysses_device_mesh = None
        self.ulysses_sequence_parallel_size = self.config.actor.get(
            "ulysses_sequence_parallel_size", 1
        )
        dp = world_size // self.ulysses_sequence_parallel_size
        if self.ulysses_sequence_parallel_size > 1:
            self.ulysses_device_mesh = init_device_mesh(
                device_name,
                mesh_shape=(dp, self.ulysses_sequence_parallel_size),
                mesh_dim_names=["dp", "sp"],
            )

        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

        self._lora_rank = self.config.model.get("lora_rank", 0)
        self._is_lora = self._lora_rank > 0

        self.role = role
        assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]

        self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
        self._is_ref = self.role in ["ref", "actor_rollout_ref"]

        self._is_offload_param = False
        self._is_offload_optimizer = False
        if self._is_actor:
            self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False)
            self._is_offload_optimizer = self.config.actor.fsdp_config.get(
                "optimizer_offload", False
            )
        elif self._is_ref:
            # TODO: it seems that manual offload is slower than FSDP offload
            self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False)

        # normalize config
        if self._is_actor:
            # note: no need to conduct `ppo_mini_batch_size *= rollout_n` anymore
            self.config.actor.ppo_mini_batch_size //= (
                self.device_mesh.size() // self.ulysses_sequence_parallel_size
            )
            assert (
                self.config.actor.ppo_mini_batch_size > 0
            ), f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization"
            # micro bsz
            if self.config.actor.ppo_micro_batch_size is not None:
                self.config.actor.ppo_micro_batch_size //= (
                    self.device_mesh.size() // self.ulysses_sequence_parallel_size
                )
                self.config.actor.ppo_micro_batch_size_per_gpu = (
                    self.config.actor.ppo_micro_batch_size
                )

            if self.config.actor.ppo_micro_batch_size_per_gpu is not None:
                assert (
                    self.config.actor.ppo_mini_batch_size
                    % self.config.actor.ppo_micro_batch_size_per_gpu
                    == 0
                ), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
                assert (
                    self.config.actor.ppo_mini_batch_size
                    // self.config.actor.ppo_micro_batch_size_per_gpu
                    > 0
                ), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"

        # normalize ref config
        if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
            self.config.ref.log_prob_micro_batch_size //= (
                self.device_mesh.size() // self.ulysses_sequence_parallel_size
            )
            self.config.ref.log_prob_micro_batch_size_per_gpu = (
                self.config.ref.log_prob_micro_batch_size
            )

    @contextmanager
    def _fsdp_offload_context(self):
        """A context manager to handle FSDP model GPU loading and CPU offloading."""
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)
        try:
            yield
        finally:
            if self._is_offload_param:
                offload_fsdp_model_to_cpu(self.actor_module_fsdp)
            torch.distributed.barrier()
            torch.cuda.empty_cache()

    def _build_model_optimizer(  # noqa: C901
        self,
        model_path,
        fsdp_config,
        optim_config,
        override_model_config,
        use_remove_padding=False,
        use_fused_kernels=False,
        enable_gradient_checkpointing=False,
        trust_remote_code=False,
        use_liger=False,
        role="actor",
        enable_activation_offload=False,
    ):
        from torch import optim
        from torch.distributed.fsdp import CPUOffload, MixedPrecision
        from transformers import (
            AutoConfig,
            AutoModelForCausalLM,
            AutoModelForVision2Seq,
        )
        from verl.utils.model import (
            get_generation_config,
            print_model_size,
            update_model_config,
        )
        from verl.utils.torch_dtypes import PrecisionType

        assert role in ["actor", "ref"]

        log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger)
        local_path = model_path

        # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
        # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
        self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)

        if self.config.model.get("custom_chat_template", None) is not None:
            if self.processor is not None:
                self.processor.chat_template = self.config.model.custom_chat_template
            else:
                self.tokenizer.chat_template = self.config.model.custom_chat_template

        torch_dtype = fsdp_config.get("model_dtype", None)
        if torch_dtype is None:
            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
        else:
            torch_dtype = PrecisionType.to_dtype(torch_dtype)

        # override model kwargs
        actor_model_config = AutoConfig.from_pretrained(
            local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2"
        )

        # patch for kimi-vl
        if getattr(actor_model_config, "model_type", None) == "kimi_vl":
            actor_model_config.text_config.topk_method = "greedy"

        self.generation_config = get_generation_config(
            local_path, trust_remote_code=trust_remote_code
        )

        override_config_kwargs = {
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(override_model_config)
        update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
        if self.rank == 0:
            print(f"Model config after override: {actor_model_config}")

        # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang
        init_context = get_init_weight_context_manager(
            use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh
        )

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys():
                actor_module_class = AutoModelForVision2Seq
            else:
                actor_module_class = AutoModelForCausalLM

            actor_module = actor_module_class.from_pretrained(
                pretrained_model_name_or_path=local_path,
                torch_dtype=torch_dtype,
                config=actor_model_config,
                trust_remote_code=trust_remote_code,
            )

            # Apply Liger kernel to the model if use_liger is set to True
            if use_liger:
                from liger_kernel.transformers.monkey_patch import (
                    _apply_liger_kernel_to_instance,
                )

                _apply_liger_kernel_to_instance(model=actor_module)

            fused_kernel_options = self.config.model.get("fused_kernel_options", None)
            fused_kernels_backend = (
                fused_kernel_options.get("impl_backend", None)
                if fused_kernel_options is not None
                else None
            )

            apply_monkey_patch(
                model=actor_module,
                use_remove_padding=use_remove_padding,
                ulysses_sp_size=self.ulysses_sequence_parallel_size,
                use_fused_kernels=use_fused_kernels,
                fused_kernels_backend=fused_kernels_backend,
            )

            # some parameters may not be in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
            actor_module.to(torch_dtype)

            if enable_gradient_checkpointing:
                actor_module.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}
                )
            if self._is_lora:
                print("Applying LoRA to actor module")
                actor_module.enable_input_require_grads()
                # Convert config to regular Python types before creating PEFT model
                lora_config = {
                    "task_type": TaskType.CAUSAL_LM,
                    "r": self.config.model.lora_rank,
                    "lora_alpha": self.config.model.lora_alpha,
                    "target_modules": convert_to_regular_types(self.config.model.target_modules),
                    "bias": "none",
                }
                actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
        torch.distributed.barrier()

        if self.rank == 0:
            print_model_size(actor_module)

        log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger)

        # We wrap FSDP for rollout as well
        mixed_precision_config = fsdp_config.get("mixed_precision", None)
        if mixed_precision_config is not None:
            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
            reduce_dtype = PrecisionType.to_dtype(
                mixed_precision_config.get("reduce_dtype", "fp32")
            )
            buffer_dtype = PrecisionType.to_dtype(
                mixed_precision_config.get("buffer_dtype", "fp32")
            )
        else:
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
            buffer_dtype = torch.float32

        mixed_precision = MixedPrecision(
            param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype
        )

        auto_wrap_policy = get_fsdp_wrap_policy(
            module=actor_module,
            config=fsdp_config.get("wrap_policy", None),
            is_lora=self.config.model.get("lora_rank", 0) > 0,
        )

        if self.rank == 0:
            print(f"wrap_policy: {auto_wrap_policy}")

        fsdp_mesh = self.device_mesh
        sharding_strategy = get_sharding_strategy(fsdp_mesh)

        # TODO: add transformer policy
        # We force reference policy to use CPUOffload to save memory.
        # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation
        cpu_offload = None if role == "actor" else CPUOffload(offload_params=True)
        fsdp_strategy = self.config.actor.strategy
        if fsdp_strategy == "fsdp":
            actor_module_fsdp = FSDP(
                actor_module,
                cpu_offload=cpu_offload,
                param_init_fn=init_fn,
                use_orig_params=False,
                auto_wrap_policy=auto_wrap_policy,
                device_id=get_device_id(),
                sharding_strategy=sharding_strategy,  # zero3
                mixed_precision=mixed_precision,
                sync_module_states=True,
                device_mesh=self.device_mesh,
                forward_prefetch=self.config.actor.fsdp_config.forward_prefetch,
            )
        elif fsdp_strategy == "fsdp2":
            assert (
                CPUOffloadPolicy is not None
            ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
            mp_policy = MixedPrecisionPolicy(
                param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
            )
            if role == "actor" and fsdp_config.offload_policy:
                cpu_offload = CPUOffloadPolicy(pin_memory=True)
                self._is_offload_param = False
                self._is_offload_optimizer = False
            else:
                cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True)

            fsdp_kwargs = {
                "mesh": fsdp_mesh,
                "mp_policy": mp_policy,
                "offload_policy": cpu_offload,
                "reshard_after_forward": fsdp_config.reshard_after_forward,
            }
            full_state = actor_module.state_dict()
            apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config)
            fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)
            actor_module_fsdp = actor_module
        else:
            raise NotImplementedError(f"not implement {fsdp_strategy}")

        if enable_activation_offload:
            enable_activation_offloading(
                actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing
            )

        log_gpu_memory_usage(f"After {role} FSDP init", logger=logger)

        # TODO: add more optimizer args into config
        if role == "actor" and optim_config is not None:
            from verl.utils.torch_functional import (
                get_constant_schedule_with_warmup,
                get_cosine_schedule_with_warmup,
            )

            actor_optimizer = optim.AdamW(
                actor_module_fsdp.parameters(),
                lr=optim_config.lr,
                betas=optim_config.get("betas", (0.9, 0.999)),
                weight_decay=optim_config.get("weight_decay", 1e-2),
            )

            total_steps = optim_config.get("total_training_steps", 0)
            num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
            warmup_style = optim_config.get("warmup_style", "constant")
            min_lr_ratio = optim_config.get("min_lr_ratio", 0.0)
            num_cycles = optim_config.get("num_cycles", 0.5)
            if num_warmup_steps < 0:
                num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0)
                num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

            if self.rank == 0:
                print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")

            if warmup_style == "constant":
                actor_lr_scheduler = get_constant_schedule_with_warmup(
                    optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps
                )
            elif warmup_style == "cosine":
                actor_lr_scheduler = get_cosine_schedule_with_warmup(
                    optimizer=actor_optimizer,
                    num_warmup_steps=num_warmup_steps,
                    num_training_steps=total_steps,
                    min_lr_ratio=min_lr_ratio,
                    num_cycles=num_cycles,
                )
            else:
                raise NotImplementedError(f"Warmup style {warmup_style} is not supported")

            log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
        else:
            actor_optimizer = None
            actor_lr_scheduler = None

        return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        from trinity.trainer.verl.dp_actor import DataParallelPPOActor

        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get("external_lib", None))

        override_model_config = OmegaConf.to_container(
            self.config.model.get("override_config", OmegaConf.create())
        )

        use_remove_padding = self.config.model.get("use_remove_padding", False)
        use_shm = self.config.model.get("use_shm", False)
        use_fused_kernels = self.config.model.get("use_fused_kernels", False)

        if self._is_actor:
            # we need the model for actor and rollout
            optim_config = self.config.actor.optim
            fsdp_config = self.config.actor.fsdp_config
            local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
            (
                self.actor_module_fsdp,
                self.actor_optimizer,
                self.actor_lr_scheduler,
                self.actor_model_config,
            ) = self._build_model_optimizer(
                model_path=local_path,
                fsdp_config=fsdp_config,
                optim_config=optim_config,
                override_model_config=override_model_config,
                use_remove_padding=use_remove_padding,
                use_fused_kernels=use_fused_kernels,
                enable_gradient_checkpointing=self.config.model.get(
                    "enable_gradient_checkpointing", False
                ),
                trust_remote_code=self.config.model.get("trust_remote_code", False),
                use_liger=self.config.model.get("use_liger", False),
                role="actor",
                enable_activation_offload=self.config.model.get("enable_activation_offload", False),
            )

            # get the original unwrapped module
            if fsdp_version(self.actor_module_fsdp) == 1:
                self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

            if self._is_offload_param:
                offload_fsdp_model_to_cpu(self.actor_module_fsdp)
                log_gpu_memory_usage("After offload actor model during init", logger=logger)

            if self._is_offload_optimizer:
                offload_fsdp_optimizer(optimizer=self.actor_optimizer)
                log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)

        if self._is_actor:
            OmegaConf.set_struct(self.config.actor, True)
            with open_dict(self.config.actor):
                self.config.actor.use_remove_padding = use_remove_padding
                self.config.actor.use_fused_kernels = use_fused_kernels
            self.actor = DataParallelPPOActor(
                config=self.config.actor,
                actor_module=self.actor_module_fsdp,
                actor_optimizer=self.actor_optimizer,
            )

        if self._is_ref:
            local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
            self.ref_module_fsdp = self._build_model_optimizer(
                model_path=local_path,
                fsdp_config=self.config.ref.fsdp_config,
                optim_config=None,
                override_model_config=override_model_config,
                use_remove_padding=use_remove_padding,
                use_fused_kernels=use_fused_kernels,
                trust_remote_code=self.config.model.get("trust_remote_code", False),
                use_liger=self.config.model.get("use_liger", False),
                role="ref",
            )[0]
            OmegaConf.set_struct(self.config.ref, True)
            with open_dict(self.config.ref):
                self.config.ref.use_remove_padding = use_remove_padding
                self.config.ref.use_fused_kernels = use_fused_kernels
            self.ref_policy = DataParallelPPOActor(
                config=self.config.ref, actor_module=self.ref_module_fsdp
            )
            self.checkpoint_manager = FSDPCheckpointManager(
                model=self.ref_module_fsdp,
                optimizer=None,
                lr_scheduler=None,
                processing_class=self.processor if self.processor is not None else self.tokenizer,
                checkpoint_config=self.config.ref.checkpoint,
            )

        if self._is_actor:
            self.flops_counter = FlopsCounter(self.actor_model_config)
            self.checkpoint_manager = FSDPCheckpointManager(
                model=self.actor_module_fsdp,
                optimizer=self.actor.actor_optimizer,
                lr_scheduler=self.actor_lr_scheduler,
                processing_class=self.processor if self.processor is not None else self.tokenizer,
                checkpoint_config=self.config.actor.checkpoint,
                config=self.config.synchronizer,
            )

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def setup_weight_sync_group(self):
        if self.config.synchronizer.sync_method == SyncMethod.NCCL:
            model = self.actor_module_fsdp
            self.named_modules = []
            self.state_dict_meta = []
            with self._fsdp_offload_context():
                if self.config.actor.strategy == "fsdp":
                    for name, module in model.named_modules():
                        if isinstance(module, FSDP):
                            self.named_modules.append((name, module))
                    for name_prefix, module in self.named_modules:
                        with FSDP.summon_full_params(module, recurse=False):
                            for name, param in module.named_parameters():
                                if isinstance(param, FlatParameter):
                                    continue
                                realname = (
                                    name_prefix[len(FSDP_PREFIX) :] + "." + name
                                    if name_prefix
                                    else name
                                )
                                self.state_dict_meta.append(
                                    (realname, str(param.dtype), tuple(param.shape))
                                )
                        param = None
                        torch.cuda.empty_cache()
                else:  # fsdp2
                    for name, param in model.named_parameters():
                        self.state_dict_meta.append((name, str(param.dtype), tuple(param.shape)))

            if torch.distributed.get_rank() == 0:
                import ray

                master_address, master_port = self.get_availale_master_addr_port()
                world_size = self.config.synchronizer.explorer_world_size + 1
                print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).")
                synchronizer = Synchronizer.get_actor(
                    namespace=self.config.synchronizer.ray_namespace
                )
                setup_ref = synchronizer.setup_weight_sync_group.remote(
                    master_address, master_port, self.state_dict_meta
                )
                timeout = self.config.synchronizer.sync_timeout
                self._model_update_group = init_process_group(
                    host=master_address,
                    port=master_port,
                    group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
                    backend="nccl",
                    timeout=timeout,
                    world_size=world_size,
                    device_id=torch.device(f"cuda:{get_device_id()}"),
                    rank=0,
                )
                torch.distributed.barrier(group=self._model_update_group)
                ray.get(setup_ref)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def sync_weight(self):
        with self._fsdp_offload_context():
            if self.config.actor.strategy == "fsdp":
                for name_prefix, module in self.named_modules:
                    with FSDP.summon_full_params(module, recurse=False):
                        if torch.distributed.get_rank() == 0:
                            for name, param in module.named_parameters():
                                if isinstance(param, FlatParameter):
                                    continue
                                torch.distributed.broadcast(
                                    param, 0, group=self._model_update_group
                                )
                    param = None
            else:  # fsdp2
                for name, param in self.actor_module_fsdp.named_parameters():
                    full_param = param.full_tensor().detach().to(device=get_device_id())
                    if torch.distributed.get_rank() == 0:
                        torch.distributed.broadcast(full_param, 0, group=self._model_update_group)
                    del full_param
            if torch.distributed.get_rank() == 0:
                torch.distributed.barrier(group=self._model_update_group)
            torch.cuda.synchronize()

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def upload_state_dict(self, trainer_step: int):
        with self._fsdp_offload_context():
            self.checkpoint_manager.upload_state_dict(trainer_step)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def set_algorithm(self, algo_config: AlgorithmConfig):
        self.actor.set_algorithm(algo_config)

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_actor(self, data: DataProto):
        # Support all hardware
        data = data.to("cpu")  # data will be moved to device with each micro batch in actor.update_policy
        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)
        if self._is_offload_optimizer:
            load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())

        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)

            # perform training
            with Timer(name="update_policy", logger=None) as timer:
                metrics = self.actor.update_policy(data=data)
            delta_time = timer.last
            global_num_tokens = data.meta_info["global_token_num"]
            estimated_flops, promised_flops = self.flops_counter.estimate_flops(
                global_num_tokens, delta_time
            )
            metrics["perf/mfu/actor"] = (
                estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
            )
            metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (
                1024**3
            )
            metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (
                1024**3
            )
            metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)

            lr = self.actor_lr_scheduler.get_last_lr()[0]
            metrics["actor/lr"] = lr
            self.actor_lr_scheduler.step()

            # TODO: here, we should return all metrics
            output = DataProto(meta_info={"metrics": metrics})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)
            output = output.to("cpu")

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)
            log_gpu_memory_usage("After offload actor model during update_actor", logger=logger)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
            log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger)
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_log_prob(self, data: DataProto):
        # when is_lora is True, we use the actor without lora applied to calculate the log_prob,
        # which is mostly used for ref log_prob calculation
        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)

        # Support all hardware
        from contextlib import nullcontext

        is_lora = data.meta_info.pop("is_lora", False)
        adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext()
        data = data.to(get_device_id())
        # we should always recompute old_log_probs when it is HybridEngine
        data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu
        data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz
        data.meta_info["temperature"] = self.config.rollout.temperature
        # perform recompute log_prob
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            with adapter_ctx:
                output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
            output = DataProto.from_dict(
                tensors={"old_log_probs": output, "entropys": entropys},
                meta_info={"temperature": self.config.rollout.temperature},
            )
            output = self.ulysses_sharding_manager.postprocess_data(output)

        output = output.to("cpu")

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # unshard the root FSDP module
        if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1:
            self.actor.actor_module._handle.reshard(True)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)
            log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger)
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_ref_log_prob(self, data: DataProto):
        if self._is_lora:
            # if _is_lora, the actor without lora applied is the ref
            data.meta_info["is_lora"] = True
            data = self.compute_log_prob(data)
            # this old_log_probs is in fact ref_log_prob
            data = DataProto.from_dict(tensors={"ref_log_prob": data.batch["old_log_probs"]})
            return data
        assert self._is_ref
        # otherwise, the class has a standalone ref model
        # Support all hardware
        data = data.to(get_device_id())

        micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
        data.meta_info["micro_batch_size"] = micro_batch_size
        data.meta_info["temperature"] = self.config.rollout.temperature
        data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
            output = DataProto.from_dict(tensors={"ref_log_prob": output})
            output = self.ulysses_sharding_manager.postprocess_data(output)

        output = output.to("cpu")

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # unshard the root FSDP module
        if self.world_size > 1 and fsdp_version(self.ref_policy.actor_module) == 1:
            self.ref_policy.actor_module._handle.reshard(True)

        return output

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(
        self,
        local_path,
        hdfs_path=None,
        global_step=0,
        max_ckpt_to_keep=None,
        model_state_dict_only=False,
    ):
        from verl.utils.logger import log_with_rank

        # only support save and load ckpt for actor
        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)

        self.checkpoint_manager.save_checkpoint(
            local_path=local_path,
            hdfs_path=hdfs_path,
            global_step=global_step,
            max_ckpt_to_keep=max_ckpt_to_keep,
            model_state_dict_only=model_state_dict_only,
        )
        dist.barrier()

        if self._is_lora and hasattr(
            getattr(self, "actor_module", self.actor_module_fsdp), "peft_config"
        ):
            lora_save_path = os.path.join(local_path, "lora_adapter")
            peft_model = getattr(self, "actor_module", self.actor_module_fsdp)
            peft_config = {}
            if dist.get_rank() == 0:
                os.makedirs(lora_save_path, exist_ok=True)
                peft_config = asdict(peft_model.peft_config.get("default", {}))
                peft_config["task_type"] = peft_config["task_type"].value
                peft_config["peft_type"] = peft_config["peft_type"].value
                peft_config["target_modules"] = list(peft_config["target_modules"])
            try:
                if fsdp_version(self.actor_module_fsdp) > 0:
                    self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name())
                lora_params = layered_summon_lora_params(self.actor_module_fsdp)
                if dist.get_rank() == 0:
                    save_file(
                        lora_params, os.path.join(lora_save_path, "adapter_model.safetensors")
                    )
                    with open(
                        os.path.join(lora_save_path, "adapter_config.json"),
                        "w",
                        encoding="utf-8",
                    ) as f:
                        json.dump(peft_config, f, ensure_ascii=False, indent=4)
            except Exception as e:
                log_with_rank(
                    f"Save LoRA Adapter Error ({e})",
                    rank=dist.get_rank(),
                    logger=logger,
                    log_only_rank_0=True,
                )

            dist.barrier()
            log_with_rank(
                f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}",
                rank=dist.get_rank(),
                logger=logger,
                log_only_rank_0=True,
            )

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
        if self._is_actor and self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)

        self.checkpoint_manager.load_checkpoint(
            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
        )

        if self._is_actor and self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)

        if self._is_actor and self._is_offload_optimizer:
            offload_fsdp_optimizer(self.actor_optimizer)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def clear_optimizer_state(self):
        print("Clear actor optimizer state")
        if self._is_offload_optimizer:
            load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())
        self.actor_optimizer.state.clear()
        self.actor_optimizer.zero_grad()
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(self.actor_optimizer)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def wait_on_save_thread(self) -> None:
        self.checkpoint_manager.wait_on_save_thread()


class CriticWorker(Worker):

    def __init__(self, config):
        super().__init__()
        self.config = config
        import torch.distributed

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend=get_nccl_backend(),
                init_method=os.environ.get("DIST_INIT_METHOD", None),
                timeout=timedelta(seconds=self.config.synchronizer.sync_timeout),
            )

        # build device mesh for Ulysses Sequence Parallel
        world_size = torch.distributed.get_world_size()
        from torch.distributed.device_mesh import init_device_mesh

        fsdp_size = self.config.model.fsdp_config.fsdp_size
        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)

        self.ulysses_device_mesh = None
        self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
        dp = world_size // self.ulysses_sequence_parallel_size
        if self.ulysses_sequence_parallel_size > 1:
            self.ulysses_device_mesh = init_device_mesh(
                device_name,
                mesh_shape=(dp, self.ulysses_sequence_parallel_size),
                mesh_dim_names=["dp", "sp"],
            )

        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

        # set FSDP offload params
        self._is_offload_param = self.config.model.fsdp_config.param_offload
        self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload

        # normalize config
        # note: no need to conduct `ppo_mini_batch_size *= rollout_n` anymore
        self.config.ppo_mini_batch_size //= (
            torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
        )
        if self.config.ppo_micro_batch_size is not None:
            self.config.ppo_micro_batch_size //= (
                torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
            )
            self.config.forward_micro_batch_size //= (
                torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
            )
            self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
            self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size

        if self.config.ppo_micro_batch_size_per_gpu is not None:
            assert (
                self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
            ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
            assert (
                self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0
            ), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
        self._is_lora = self.config.model.get("lora_rank", 0) > 0

    def _build_critic_model_optimizer(self, config):  # noqa: C901
        # the following imports are necessary
        from torch import optim
        from torch.distributed.fsdp import MixedPrecision
        from verl.utils.model import load_valuehead_model, print_model_size
        from verl.utils.torch_dtypes import PrecisionType

        use_shm = config.model.get("use_shm", False)
        local_path = copy_to_local(config.model.path, use_shm=use_shm)
        # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info
        # using random initialized model from any architecture. May not be the same as Actor.
        tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm)
        self.tokenizer = hf_tokenizer(
            tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)
        )
        self.processor = hf_processor(
            tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)
        )

        if self.config.model.get("custom_chat_template", None) is not None:
            if self.processor is not None:
                self.processor.chat_template = self.config.model.custom_chat_template
            else:
                self.tokenizer.chat_template = self.config.model.custom_chat_template

        from omegaconf import OmegaConf

        override_config = OmegaConf.to_container(
            self.config.model.get("override_config", OmegaConf.create())
        )
        override_config_kwargs = {
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(override_config)
        if self.rank == 0:
            print(f"Critic overriding config {override_config_kwargs}")

        torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
        torch_dtype = PrecisionType.to_dtype(torch_dtype)

        from transformers import AutoConfig

        critic_model_config = AutoConfig.from_pretrained(
            local_path,
            attn_implementation="flash_attention_2",
            trust_remote_code=config.model.get("trust_remote_code", False),
        )
        critic_model_config.num_labels = 1
        # patch for kimi-vl
        if getattr(critic_model_config, "model_type", None) == "kimi_vl":
            critic_model_config.text_config.topk_method = "greedy"

        init_context = get_init_weight_context_manager(
            use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh
        )

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            critic_model_config.classifier_dropout = 0.0
            critic_model_config.hidden_dropout = "0"
            critic_model_config.summary_dropout_prob = 0.0

            critic_module = load_valuehead_model(
                local_path,
                torch_dtype,
                critic_model_config,
                config.model.get("trust_remote_code", False),
            )

            use_remove_padding = config.model.get("use_remove_padding", False)
            apply_monkey_patch(
                model=critic_module,
                use_remove_padding=use_remove_padding,
                ulysses_sp_size=self.ulysses_sequence_parallel_size,
            )

            # some parameters may not be in torch_dtype
            critic_module.to(torch_dtype)

            if config.model.get("enable_gradient_checkpointing", False):
                critic_module.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}
                )

        if self._is_lora:
            print("Applying LoRA to critic module")
            critic_module.enable_input_require_grads()
            # Convert config to regular Python types before creating PEFT model
            lora_config = {
                "task_type": TaskType.CAUSAL_LM,
                "r": self.config.model.lora_rank,
                "lora_alpha": self.config.model.lora_alpha,
                "target_modules": convert_to_regular_types(self.config.model.target_modules),
                "bias": "none",
            }
            critic_module = get_peft_model(critic_module, LoraConfig(**lora_config))

        if self.rank == 0:
            print_model_size(critic_module)

        self.critic_model_config = critic_model_config

        fsdp_config = self.config.model.fsdp_config
        mixed_precision_config = fsdp_config.get("mixed_precision", None)
        if mixed_precision_config is not None:
            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
            reduce_dtype = PrecisionType.to_dtype(
                mixed_precision_config.get("reduce_dtype", "fp32")
            )
            buffer_dtype = PrecisionType.to_dtype(
                mixed_precision_config.get("buffer_dtype", "fp32")
            )
        else:
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
            buffer_dtype = torch.float32

        mixed_precision = MixedPrecision(
            param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype
        )

        auto_wrap_policy = get_fsdp_wrap_policy(
            module=critic_module,
            config=self.config.model.fsdp_config.wrap_policy,
            is_lora=self.config.model.get("lora_rank", 0) > 0,
        )

        log_gpu_memory_usage("Before critic FSDP", logger=None)

        fsdp_mesh = self.device_mesh
        sharding_strategy = get_sharding_strategy(fsdp_mesh)

        # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation
        if config.strategy == "fsdp":
            critic_module = FSDP(
                critic_module,
                param_init_fn=init_fn,
                use_orig_params=False,
                auto_wrap_policy=auto_wrap_policy,
                device_id=get_device_id(),
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
                sync_module_states=True,
                forward_prefetch=self.config.model.fsdp_config.forward_prefetch,
                device_mesh=self.device_mesh,
                cpu_offload=None,
            )
        elif config.strategy == "fsdp2":
            assert (
                CPUOffloadPolicy is not None
            ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
            mp_policy = MixedPrecisionPolicy(
                param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
            )
            offload_policy = None
            if fsdp_config.offload_policy:
                self._is_offload_param = False
                self._is_offload_optimizer = False
                offload_policy = CPUOffloadPolicy(pin_memory=True)

            fsdp_kwargs = {
                "mesh": fsdp_mesh,
                "mp_policy": mp_policy,
                "offload_policy": offload_policy,
                "reshard_after_forward": fsdp_config.reshard_after_forward,
            }
            full_state = critic_module.state_dict()
            apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config)
            fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy)
        else:
            raise NotImplementedError(f"Unknown strategy {config.strategy}")

        if config.model.get("enable_activation_offload", False):
            enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False)
            enable_activation_offloading(
                critic_module, config.strategy, enable_gradient_checkpointing
            )

        log_gpu_memory_usage("After critic FSDP", logger=None)

        critic_optimizer = optim.AdamW(
            critic_module.parameters(),
            lr=config.optim.lr,
            betas=config.optim.get("betas", (0.9, 0.999)),
            weight_decay=config.optim.get("weight_decay", 1e-2),
        )

        total_steps = config.optim.get("total_training_steps", 0)
        num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
        warmup_style = config.optim.get("warmup_style", "constant")
        if num_warmup_steps < 0:
            num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

        if self.rank == 0:
            print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")

        from verl.utils.torch_functional import (
            get_constant_schedule_with_warmup,
            get_cosine_schedule_with_warmup,
        )

        if warmup_style == "constant":
            critic_lr_scheduler = get_constant_schedule_with_warmup(
                optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps
            )
        elif warmup_style == "cosine":
            critic_lr_scheduler = get_cosine_schedule_with_warmup(
                optimizer=critic_optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_steps,
            )
        else:
            raise NotImplementedError(f"Warmup style {warmup_style} is not supported")

        return critic_module, critic_optimizer, critic_lr_scheduler

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get("external_lib", None))

        from verl.workers.critic import DataParallelPPOCritic

        (
            self.critic_module,
            self.critic_optimizer,
            self.critic_lr_scheduler,
        ) = self._build_critic_model_optimizer(self.config)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)
            log_gpu_memory_usage("After offload critic model during init", logger=logger)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.critic_optimizer)
            log_gpu_memory_usage("After offload critic optimizer during init", logger=logger)

        self.critic = DataParallelPPOCritic(
            config=self.config,
            critic_module=self.critic_module,
            critic_optimizer=self.critic_optimizer,
        )

        self.flops_counter = FlopsCounter(self.critic_model_config)
        self.checkpoint_manager = FSDPCheckpointManager(
            model=self.critic_module,
            optimizer=self.critic_optimizer,
            lr_scheduler=self.critic_lr_scheduler,
            processing_class=self.processor if self.processor is not None else self.tokenizer,
            checkpoint_config=self.config.checkpoint,
        )

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_values(self, data: DataProto):
        # Support all hardware
        data = data.to(get_device_id())

        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)
        micro_batch_size = self.config.forward_micro_batch_size_per_gpu
        data.meta_info["micro_batch_size"] = micro_batch_size
        data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz
        # perform forward computation
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)
            values = self.critic.compute_values(data=data)
            output = DataProto.from_dict(tensors={"values": values})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)

        output = output.to("cpu")
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_critic(self, data: DataProto):
        # Support all hardware
        data = data.to(get_device_id())
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)
        if self._is_offload_optimizer:
            load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id())

        # perform forward computation
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)

            with Timer(name="update_critic", logger=None) as timer:
                metrics = self.critic.update_critic(data=data)
            delta_time = timer.last

            global_num_tokens = data.meta_info["global_token_num"]
            estimated_flops, promised_flops = self.flops_counter.estimate_flops(
                global_num_tokens, delta_time
            )
            metrics["perf/mfu/critic"] = (
                estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
            )

            self.critic_lr_scheduler.step()
            lr = self.critic_lr_scheduler.get_last_lr()[0]
            metrics["critic/lr"] = lr

            output = DataProto(batch=None, meta_info={"metrics": metrics})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.critic_optimizer)
        output = output.to("cpu")
        return output

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        import torch

        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)

        self.checkpoint_manager.save_checkpoint(
            local_path=local_path,
            hdfs_path=hdfs_path,
            global_step=global_step,
            max_ckpt_to_keep=max_ckpt_to_keep,
        )

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
        import torch

        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)

        self.checkpoint_manager.load_checkpoint(
            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
        )

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)

        if self._is_offload_optimizer:
            offload_fsdp_optimizer(self.critic_optimizer)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def clear_optimizer_state(self):
        print("Clear critic optimizer state")
        if self._is_offload_optimizer:
            load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id())
        self.critic_optimizer.state.clear()
        self.critic_optimizer.zero_grad()
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(self.critic_optimizer)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def wait_on_save_thread(self) -> None:
        self.checkpoint_manager.wait_on_save_thread()