# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP-based actor, reference, and critic workers for the PPO algorithm.
Modified from https://github.com/volcengine/verl/blob/0758489422e8d41a89e6c36d4c477714520f0dcc/verl/workers/fsdp_workers.py
"""
import json
import logging
import os
import warnings
from dataclasses import asdict
from datetime import timedelta
import psutil
import torch
import torch.distributed
import torch.distributed as dist
import vllm # noqa: F401 ; import vllm to set NCCL_CUMEM_ENABLE automatically.
from codetiming import Timer
from omegaconf import DictConfig, open_dict
from peft import LoraConfig, TaskType, get_peft_model
from safetensors.torch import save_file
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FlatParameter
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import FSDP_PREFIX
from verl import DataProto
from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.activation_offload import enable_activation_offloading
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.device import get_torch_device, is_cuda_available
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
apply_fsdp2,
fsdp2_load_full_state_dict,
fsdp_version,
get_fsdp_wrap_policy,
get_init_weight_context_manager,
init_fn,
layered_summon_lora_params,
load_fsdp_model_to_gpu,
load_fsdp_optimizer,
offload_fsdp_model_to_cpu,
offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.py_functional import convert_to_regular_types
from verl.workers.fsdp_workers import (
create_device_mesh,
device_name,
get_sharding_strategy,
)
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
from trinity.common.config import AlgorithmConfig
from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
from trinity.utils.distributed import init_process_group, is_ipv6_address
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class ActorRolloutRefWorker(Worker):
"""
This worker can be instantiated as a standalone actor, a standalone rollout, a standalone
reference policy, or a hybrid engine, depending on ``config.rollout``.
"""
def __init__(self, config: DictConfig, role: str):
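"""Initialize the worker: create the global process group if needed, build the FSDP and
Ulysses sequence-parallel device meshes, record the role (actor / ref / rollout), set up
param/optimizer offload flags, and normalize the PPO mini/micro batch sizes by the
data-parallel size."""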
super().__init__()
self.config = config
import torch.distributed
if not torch.distributed.is_initialized():
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl" if is_cuda_available else "cpu:gloo,npu:hccl",
rank=rank,
world_size=world_size,
timeout=timedelta(seconds=self.config.synchronizer.sync_timeout),
)
# build device mesh for FSDP
world_size = torch.distributed.get_world_size()
# TODO(sgm): support FSDP hybrid shard for larger model
self.device_mesh = create_device_mesh(
world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size
)
# build device mesh for Ulysses Sequence Parallel
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.actor.get(
"ulysses_sequence_parallel_size", 1
)
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
device_name,
mesh_shape=(dp, self.ulysses_sequence_parallel_size),
mesh_dim_names=["dp", "sp"],
)
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
self._lora_rank = self.config.model.get("lora_rank", 0)
self._is_lora = self._lora_rank > 0
self.role = role
assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]
self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
self._is_ref = self.role in ["ref", "actor_rollout_ref"]
self._is_offload_param = False
self._is_offload_optimizer = False
if self._is_actor:
self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False)
self._is_offload_optimizer = self.config.actor.fsdp_config.get(
"optimizer_offload", False
)
elif self._is_ref:
# TODO: it seems that manual offload is slower than FSDP offload
self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False)
# normalize config
if self._is_actor:
self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
self.config.actor.ppo_mini_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
assert (
self.config.actor.ppo_mini_batch_size > 0
), f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization"
# micro bsz
if self.config.actor.ppo_micro_batch_size is not None:
self.config.actor.ppo_micro_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
self.config.actor.ppo_micro_batch_size_per_gpu = (
self.config.actor.ppo_micro_batch_size
)
if self.config.actor.ppo_micro_batch_size_per_gpu is not None:
assert (
self.config.actor.ppo_mini_batch_size
% self.config.actor.ppo_micro_batch_size_per_gpu
== 0
), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
assert (
self.config.actor.ppo_mini_batch_size
// self.config.actor.ppo_micro_batch_size_per_gpu
> 0
), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
# normalize ref config
if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
self.config.ref.log_prob_micro_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
self.config.ref.log_prob_micro_batch_size_per_gpu = (
self.config.ref.log_prob_micro_batch_size
)
def _build_model_optimizer( # noqa: C901
self,
model_path,
fsdp_config,
optim_config,
override_model_config,
use_remove_padding=False,
use_fused_kernels=False,
enable_gradient_checkpointing=False,
trust_remote_code=False,
use_liger=False,
role="actor",
enable_activation_offload=False,
):
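"""Build the HF model for the given role, apply optional Liger kernels, monkey patches and
LoRA, wrap it with FSDP or FSDP2, and (for the actor) create the AdamW optimizer and LR
scheduler. Returns ``(module_fsdp, optimizer, lr_scheduler, model_config)``."""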
from torch import optim
from torch.distributed.fsdp import CPUOffload, MixedPrecision
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForVision2Seq,
)
from verl.utils.model import (
get_generation_config,
print_model_size,
update_model_config,
)
from verl.utils.torch_dtypes import PrecisionType
assert role in ["actor", "ref"]
log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger)
local_path = model_path
# Note that we have to create the model in fp32; otherwise, the optimizer states would be in bf16, which is incorrect
# TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)
torch_dtype = fsdp_config.get("model_dtype", None)
if torch_dtype is None:
torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
else:
torch_dtype = PrecisionType.to_dtype(torch_dtype)
# override model kwargs
actor_model_config = AutoConfig.from_pretrained(
local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2"
)
# patch for kimi-vl
if getattr(actor_model_config, "model_type", None) == "kimi_vl":
actor_model_config.text_config.topk_method = "greedy"
self.generation_config = get_generation_config(
local_path, trust_remote_code=trust_remote_code
)
override_config_kwargs = {
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
if self.rank == 0:
print(f"Model config after override: {actor_model_config}")
# NOTE(fix me): tie_word_embedding causes meta_tensor init to hang
init_context = get_init_weight_context_manager(
use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh
)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys():
actor_module_class = AutoModelForVision2Seq
else:
actor_module_class = AutoModelForCausalLM
actor_module = actor_module_class.from_pretrained(
pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=actor_model_config,
trust_remote_code=trust_remote_code,
)
# Apply Liger kernel to the model if use_liger is set to True
if use_liger:
from liger_kernel.transformers.monkey_patch import (
_apply_liger_kernel_to_instance,
)
_apply_liger_kernel_to_instance(model=actor_module)
apply_monkey_patch(
model=actor_module,
use_remove_padding=use_remove_padding,
ulysses_sp_size=self.ulysses_sequence_parallel_size,
use_fused_kernels=use_fused_kernels,
)
# some parameters may not be in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
actor_module.to(torch_dtype)
if enable_gradient_checkpointing:
actor_module.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
if self._is_lora:
print("Applying LoRA to actor module")
actor_module.enable_input_require_grads()
# Convert config to regular Python types before creating PEFT model
lora_config = {
"task_type": TaskType.CAUSAL_LM,
"r": self.config.model.lora_rank,
"lora_alpha": self.config.model.lora_alpha,
"target_modules": convert_to_regular_types(self.config.model.target_modules),
"bias": "none",
}
actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
torch.distributed.barrier()
if self.rank == 0:
print_model_size(actor_module)
log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger)
# We wrap FSDP for rollout as well
mixed_precision_config = fsdp_config.get("mixed_precision", None)
if mixed_precision_config is not None:
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
reduce_dtype = PrecisionType.to_dtype(
mixed_precision_config.get("reduce_dtype", "fp32")
)
buffer_dtype = PrecisionType.to_dtype(
mixed_precision_config.get("buffer_dtype", "fp32")
)
else:
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
buffer_dtype = torch.float32
mixed_precision = MixedPrecision(
param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype
)
auto_wrap_policy = get_fsdp_wrap_policy(
module=actor_module,
config=fsdp_config.get("wrap_policy", None),
is_lora=self.config.model.get("lora_rank", 0) > 0,
)
if self.rank == 0:
print(f"wrap_policy: {auto_wrap_policy}")
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
# TODO: add transformer policy
# We force reference policy to use CPUOffload to save memory.
# We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation
cpu_offload = None if role == "actor" else CPUOffload(offload_params=True)
fsdp_strategy = self.config.actor.strategy
if fsdp_strategy == "fsdp":
actor_module_fsdp = FSDP(
actor_module,
cpu_offload=cpu_offload,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=get_torch_device().current_device(),
sharding_strategy=sharding_strategy, # zero3
mixed_precision=mixed_precision,
sync_module_states=True,
device_mesh=self.device_mesh,
forward_prefetch=False,
)
elif fsdp_strategy == "fsdp2":
assert (
CPUOffloadPolicy is not None
), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
)
if role == "actor" and fsdp_config.offload_policy:
cpu_offload = CPUOffloadPolicy(pin_memory=True)
self._is_offload_param = False
self._is_offload_optimizer = False
else:
cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True)
fsdp_kwargs = {
"mesh": fsdp_mesh,
"mp_policy": mp_policy,
"offload_policy": cpu_offload,
"reshard_after_forward": fsdp_config.reshard_after_forward,
}
full_state = actor_module.state_dict()
apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config)
fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)
actor_module_fsdp = actor_module
else:
raise NotImplementedError(f"not implement {fsdp_strategy}")
if enable_activation_offload:
enable_activation_offloading(
actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing
)
log_gpu_memory_usage(f"After {role} FSDP init", logger=logger)
# TODO: add more optimizer args into config
if role == "actor" and optim_config is not None:
from verl.utils.torch_functional import (
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
)
actor_optimizer = optim.AdamW(
actor_module_fsdp.parameters(),
lr=optim_config.lr,
betas=optim_config.get("betas", (0.9, 0.999)),
weight_decay=optim_config.get("weight_decay", 1e-2),
)
total_steps = optim_config.get("total_training_steps", 0)
num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
warmup_style = optim_config.get("warmup_style", "constant")
min_lr_ratio = optim_config.get("min_lr_ratio", 0.0)
num_cycles = optim_config.get("num_cycles", 0.5)
if num_warmup_steps < 0:
num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
if self.rank == 0:
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
if warmup_style == "constant":
actor_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps
)
elif warmup_style == "cosine":
actor_lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=actor_optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps,
min_lr_ratio=min_lr_ratio,
num_cycles=num_cycles,
)
else:
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
else:
actor_optimizer = None
actor_lr_scheduler = None
return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
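"""Instantiate the actor and/or reference policy modules, optionally offload params and
optimizer to CPU, and create the corresponding FSDP checkpoint manager."""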
from trinity.trainer.verl.dp_actor import DataParallelPPOActor
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get("external_lib", None))
from omegaconf import OmegaConf
override_model_config = OmegaConf.to_container(
self.config.model.get("override_config", OmegaConf.create())
)
use_remove_padding = self.config.model.get("use_remove_padding", False)
use_shm = self.config.model.get("use_shm", False)
use_fused_kernels = self.config.model.get("use_fused_kernels", False)
if self._is_actor:
# we need the model for actor and rollout
optim_config = self.config.actor.optim
fsdp_config = self.config.actor.fsdp_config
local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
(
self.actor_module_fsdp,
self.actor_optimizer,
self.actor_lr_scheduler,
self.actor_model_config,
) = self._build_model_optimizer(
model_path=local_path,
fsdp_config=fsdp_config,
optim_config=optim_config,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
use_fused_kernels=use_fused_kernels,
enable_gradient_checkpointing=self.config.model.get(
"enable_gradient_checkpointing", False
),
trust_remote_code=self.config.model.get("trust_remote_code", False),
use_liger=self.config.model.get("use_liger", False),
role="actor",
enable_activation_offload=self.config.model.get("enable_activation_offload", False),
)
# get the original unwrapped module
if fsdp_version(self.actor_module_fsdp) == 1:
self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
log_gpu_memory_usage("After offload actor model during init", logger=logger)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)
# load from checkpoint
if self._is_actor:
OmegaConf.set_struct(self.config.actor, True)
with open_dict(self.config.actor):
self.config.actor.use_remove_padding = use_remove_padding
self.config.actor.use_fused_kernels = use_fused_kernels
self.actor = DataParallelPPOActor(
config=self.config.actor,
actor_module=self.actor_module_fsdp,
actor_optimizer=self.actor_optimizer,
)
if self._is_ref:
local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
self.ref_module_fsdp = self._build_model_optimizer(
model_path=local_path,
fsdp_config=self.config.ref.fsdp_config,
optim_config=None,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
use_fused_kernels=use_fused_kernels,
trust_remote_code=self.config.model.get("trust_remote_code", False),
use_liger=self.config.model.get("use_liger", False),
role="ref",
)[0]
OmegaConf.set_struct(self.config.ref, True)
with open_dict(self.config.ref):
self.config.ref.use_remove_padding = use_remove_padding
self.config.ref.use_fused_kernels = use_fused_kernels
self.ref_policy = DataParallelPPOActor(
config=self.config.ref, actor_module=self.ref_module_fsdp
)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.ref_module_fsdp,
optimizer=None,
lr_scheduler=None,
processing_class=self.processor if self.processor is not None else self.tokenizer,
)
if self._is_actor:
self.flops_counter = FlopsCounter(self.actor_model_config)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.actor_module_fsdp,
optimizer=self.actor.actor_optimizer,
lr_scheduler=self.actor_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_contents=self.config.actor.checkpoint.contents,
)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def setup_weight_sync_group(self):
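"""When NCCL-based weight sync is configured, record per-parameter metadata of the
FSDP-wrapped actor and, on rank 0, create a process group with the explorer that is later
used to broadcast weights."""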
if (
hasattr(self.config, "synchronizer")
and getattr(self.config.synchronizer, "sync_method", None) == SyncMethod.NCCL
):
model = self.actor_module_fsdp
self.named_modules = []
self.state_dict_meta = []
for name, module in model.named_modules():
if isinstance(module, FSDP):
self.named_modules.append((name, module))
for name_prefix, module in self.named_modules:
with FSDP.summon_full_params(module, recurse=False):
for name, param in module.named_parameters():
if isinstance(param, FlatParameter):
continue
realname = (
name_prefix[len(FSDP_PREFIX) :] + "." + name if name_prefix else name
)
self.state_dict_meta.append(
(realname, str(param.dtype), tuple(param.shape))
)
param = None
torch.cuda.empty_cache()
if torch.distributed.get_rank() == 0:
import ray
master_address, master_port = self.get_availale_master_addr_port()
world_size = self.config.synchronizer.explorer_world_size + 1
print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).")
explorer = ray.get_actor(self.config.explorer_name)
setup_ref = explorer.setup_weight_sync_group.remote(
master_address, master_port, self.state_dict_meta
)
if is_ipv6_address(master_address):
# using tcp://ipv6:port will lead to ValueError
init_method = f"tcp://[{master_address}]:{master_port}"
else:
init_method = f"tcp://{master_address}:{master_port}"
timeout = self.config.synchronizer.sync_timeout
self._model_update_group = init_process_group(
backend="nccl",
init_method=init_method,
timeout=timeout,
world_size=world_size,
rank=0,
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
)
ray.get(setup_ref)
torch.distributed.barrier()
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def sync_weight(self):
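"""Broadcast the summoned full actor parameters from rank 0 to the explorer through the
weight-sync process group created in ``setup_weight_sync_group``."""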
for name_prefix, module in self.named_modules:
with FSDP.summon_full_params(module, recurse=False):
if torch.distributed.get_rank() == 0:
for name, param in module.named_parameters():
if isinstance(param, FlatParameter):
continue
torch.distributed.broadcast(param, 0, group=self._model_update_group)
param = None
torch.distributed.barrier()
torch.cuda.synchronize()
torch.cuda.empty_cache()
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def set_algorithm(self, algo_config: AlgorithmConfig):
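"""Update the actor's algorithm configuration (delegates to ``DataParallelPPOActor.set_algorithm``)."""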
self.actor.set_algorithm(algo_config)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
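"""Run a PPO policy update on the given batch and return training and perf metrics
(MFU, memory usage, learning rate) in ``DataProto.meta_info["metrics"]``."""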
# Support all hardware backends
data = data.to(get_torch_device().current_device())
assert self._is_actor
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
if self._is_offload_optimizer:
load_fsdp_optimizer(
optimizer=self.actor_optimizer, device_id=get_torch_device().current_device()
)
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
# perform training
with Timer(name="update_policy", logger=None) as timer:
metrics = self.actor.update_policy(data=data)
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(
global_num_tokens, delta_time
)
metrics["perf/mfu/actor"] = (
estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
)
metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (
1024**3
)
metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (
1024**3
)
metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)
lr = self.actor_lr_scheduler.get_last_lr()[0]
metrics["actor/lr"] = lr
self.actor_lr_scheduler.step()
# TODO: here, we should return all metrics
output = DataProto(meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
log_gpu_memory_usage("After offload actor model during update_actor", logger=logger)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_prob(self, data: DataProto):
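"""Recompute log-probs (and entropies) of the rollout tokens under the current actor. When
``is_lora`` is set in ``meta_info``, LoRA adapters are disabled so the base model acts as
the reference policy."""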
# when is_lora is True, we use the actor without lora applied to calculate the log_prob
# which is mostly used for ref log_prob calculation
assert self._is_actor
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
# Support all hardware backends
from contextlib import nullcontext
is_lora = data.meta_info.pop("is_lora", False)
adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext()
data = data.to(get_torch_device().current_device())
# we should always recompute old_log_probs when it is HybridEngine
data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu
data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz
data.meta_info["temperature"] = self.config.rollout.temperature
# perform recompute log_prob
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
with adapter_ctx:
output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
output = DataProto.from_dict(
tensors={"old_log_probs": output, "entropys": entropys},
meta_info={"temperature": self.config.rollout.temperature},
)
output = self.ulysses_sharding_manager.postprocess_data(output)
output = output.to("cpu")
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1:
self.actor.actor_module._handle.reshard(True)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
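"""Compute reference-policy log-probs, either through the LoRA-disabled actor or through
the standalone reference model."""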
if self._is_lora:
# if _is_lora, actor without lora applied is the ref
data.meta_info["is_lora"] = True
data = self.compute_log_prob(data)
# this old_log_probs is in fact ref_log_prob
data = DataProto.from_dict(tensors={"ref_log_prob": data.batch["old_log_probs"]})
return data
assert self._is_ref
# otherwise, this worker has a standalone ref model
# Support all hardware backends
data = data.to(get_torch_device().current_device())
micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
data.meta_info["micro_batch_size"] = micro_batch_size
data.meta_info["temperature"] = self.config.rollout.temperature
data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
output = DataProto.from_dict(tensors={"ref_log_prob": output})
output = self.ulysses_sharding_manager.postprocess_data(output)
output = output.to("cpu")
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1 and fsdp_version(self.ref_policy.actor_module) == 1:
self.ref_policy.actor_module._handle.reshard(True)
return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
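"""Save the actor checkpoint via the FSDP checkpoint manager and, when LoRA is enabled,
additionally save the adapter weights and config under ``lora_adapter/``."""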
# only support save and load ckpt for actor
assert self._is_actor
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
self.checkpoint_manager.save_checkpoint(
local_path=local_path,
hdfs_path=hdfs_path,
global_step=global_step,
max_ckpt_to_keep=max_ckpt_to_keep,
)
dist.barrier()
if self._is_lora and hasattr(
getattr(self, "actor_module", self.actor_module_fsdp), "peft_config"
):
lora_save_path = os.path.join(local_path, "lora_adapter")
peft_model = getattr(self, "actor_module", self.actor_module_fsdp)
peft_config = {}
if dist.get_rank() == 0:
os.makedirs(lora_save_path, exist_ok=True)
peft_config = asdict(peft_model.peft_config.get("default", {}))
peft_config["task_type"] = peft_config["task_type"].value
peft_config["peft_type"] = peft_config["peft_type"].value
peft_config["target_modules"] = list(peft_config["target_modules"])
try:
if fsdp_version(self.actor_module_fsdp) > 0:
self.actor_module_fsdp = self.actor_module_fsdp.cuda()
lora_params = layered_summon_lora_params(self.actor_module_fsdp)
if dist.get_rank() == 0:
save_file(
lora_params, os.path.join(lora_save_path, "adapter_model.safetensors")
)
with open(
os.path.join(lora_save_path, "adapter_config.json"),
"w",
encoding="utf-8",
) as f:
json.dump(peft_config, f, ensure_ascii=False, indent=4)
except Exception as e:
if dist.get_rank() == 0:
print(f"[rank-{self.rank}]: Save LoRA Adapter Error ({e})")
dist.barrier()
if dist.get_rank() == 0:
print(f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}")
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
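"""Load an actor checkpoint via the FSDP checkpoint manager, honoring the param/optimizer
offload settings."""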
if self._is_actor and self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
self.checkpoint_manager.load_checkpoint(
local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
)
if self._is_actor and self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
if self._is_actor and self._is_offload_optimizer:
offload_fsdp_optimizer(self.actor_optimizer)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def clear_optimizer_state(self):
print("Clear actor optimizer state")
if self._is_offload_optimizer:
load_fsdp_optimizer(
optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()
)
self.actor_optimizer.state.clear()
self.actor_optimizer.zero_grad()
if self._is_offload_optimizer:
offload_fsdp_optimizer(self.actor_optimizer)
class CriticWorker(Worker):
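"""FSDP-based worker that builds, updates, and checkpoints the PPO critic (value) model."""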
def __init__(self, config):
super().__init__()
self.config = config  # must be assigned before it is used in init_process_group below
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
backend="nccl" if is_cuda_available else "hccl",
timeout=timedelta(seconds=self.config.synchronizer.sync_timeout),
)
# build device mesh for Ulysses Sequence Parallel
world_size = torch.distributed.get_world_size()
from torch.distributed.device_mesh import init_device_mesh
fsdp_size = self.config.model.fsdp_config.fsdp_size
self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
device_name,
mesh_shape=(dp, self.ulysses_sequence_parallel_size),
mesh_dim_names=["dp", "sp"],
)
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
# set FSDP offload params
self._is_offload_param = self.config.model.fsdp_config.param_offload
self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload
# normalize config
self.config.ppo_mini_batch_size *= self.config.rollout_n
self.config.ppo_mini_batch_size //= (
torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
)
if self.config.ppo_micro_batch_size is not None:
self.config.ppo_micro_batch_size //= (
torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
)
self.config.forward_micro_batch_size //= (
torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
)
self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size
if self.config.ppo_micro_batch_size_per_gpu is not None:
assert (
self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
assert (
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0
), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
self._is_lora = self.config.model.get("lora_rank", 0) > 0
def _build_critic_model_optimizer(self, config):
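"""Load the critic as an ``AutoModelForTokenClassification`` with a single value head
(``num_labels=1``), apply optional LoRA, wrap it with FSDP or FSDP2, and build its AdamW
optimizer and LR scheduler."""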
# the following imports are necessary
from torch import optim
from torch.distributed.fsdp import MixedPrecision
from verl.utils.model import print_model_size
from verl.utils.torch_dtypes import PrecisionType
use_shm = config.model.get("use_shm", False)
local_path = copy_to_local(config.model.path, use_shm=use_shm)
# Note that the actor and critic tokenizers may differ, so override the tokenizer info with the actor's.
# The critic may use a randomly initialized model from any architecture, which may not match the actor's.
tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm)
self.tokenizer = hf_tokenizer(
tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)
)
self.processor = hf_processor(
tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)
)
from omegaconf import OmegaConf
override_config = OmegaConf.to_container(
self.config.model.get("override_config", OmegaConf.create())
)
override_config_kwargs = {
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_config)
if self.rank == 0:
print(f"Critic overriding config {override_config_kwargs}")
torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
torch_dtype = PrecisionType.to_dtype(torch_dtype)
from transformers import AutoConfig, AutoModelForTokenClassification
critic_model_config = AutoConfig.from_pretrained(
local_path,
attn_implementation="flash_attention_2",
trust_remote_code=config.model.get("trust_remote_code", False),
)
critic_model_config.num_labels = 1
# patch for kimi-vl
if getattr(critic_model_config, "model_type", None) == "kimi_vl":
critic_model_config.text_config.topk_method = "greedy"
init_context = get_init_weight_context_manager(
use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh
)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
critic_model_config.classifier_dropout = 0.0
critic_model_config.hidden_dropout = "0"
critic_module = AutoModelForTokenClassification.from_pretrained(
pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=critic_model_config,
trust_remote_code=config.model.get("trust_remote_code", False),
)
use_remove_padding = config.model.get("use_remove_padding", False)
apply_monkey_patch(
model=critic_module,
use_remove_padding=use_remove_padding,
ulysses_sp_size=self.ulysses_sequence_parallel_size,
)
# some parameters may not be in torch_dtype
critic_module.to(torch_dtype)
if config.model.get("enable_gradient_checkpointing", False):
critic_module.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
if self._is_lora:
print("Applying LoRA to critic module")
critic_module.enable_input_require_grads()
# Convert config to regular Python types before creating PEFT model
lora_config = {
"task_type": TaskType.CAUSAL_LM,
"r": self.config.model.lora_rank,
"lora_alpha": self.config.model.lora_alpha,
"target_modules": convert_to_regular_types(self.config.model.target_modules),
"bias": "none",
}
critic_module = get_peft_model(critic_module, LoraConfig(**lora_config))
if self.rank == 0:
print_model_size(critic_module)
self.critic_model_config = critic_model_config
fsdp_config = self.config.model.fsdp_config
mixed_precision_config = fsdp_config.get("mixed_precision", None)
if mixed_precision_config is not None:
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
reduce_dtype = PrecisionType.to_dtype(
mixed_precision_config.get("reduce_dtype", "fp32")
)
buffer_dtype = PrecisionType.to_dtype(
mixed_precision_config.get("buffer_dtype", "fp32")
)
else:
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
buffer_dtype = torch.float32
mixed_precision = MixedPrecision(
param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype
)
auto_wrap_policy = get_fsdp_wrap_policy(
module=critic_module,
config=self.config.model.fsdp_config.wrap_policy,
is_lora=self.config.model.get("lora_rank", 0) > 0,
)
log_gpu_memory_usage("Before critic FSDP", logger=None)
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
# Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation
if config.strategy == "fsdp":
critic_module = FSDP(
critic_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=get_torch_device().current_device(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
sync_module_states=True,
forward_prefetch=False,
device_mesh=self.device_mesh,
cpu_offload=None,
)
elif config.strategy == "fsdp2":
assert (
CPUOffloadPolicy is not None
), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
)
offload_policy = None
if fsdp_config.offload_policy:
self._is_offload_param = False
self._is_offload_optimizer = False
offload_policy = CPUOffloadPolicy(pin_memory=True)
fsdp_kwargs = {
"mesh": fsdp_mesh,
"mp_policy": mp_policy,
"offload_policy": offload_policy,
"reshard_after_forward": fsdp_config.reshard_after_forward,
}
full_state = critic_module.state_dict()
apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config)
fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy)
else:
raise NotImplementedError(f"Unknown strategy {config.strategy}")
if config.model.get("enable_activation_offload", False):
enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False)
enable_activation_offloading(
critic_module, config.strategy, enable_gradient_checkpointing
)
log_gpu_memory_usage("After critic FSDP", logger=None)
critic_optimizer = optim.AdamW(
critic_module.parameters(),
lr=config.optim.lr,
betas=config.optim.get("betas", (0.9, 0.999)),
weight_decay=config.optim.get("weight_decay", 1e-2),
)
total_steps = config.optim.get("total_training_steps", 0)
num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
warmup_style = config.optim.get("warmup_style", "constant")
if num_warmup_steps < 0:
num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
if self.rank == 0:
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
from verl.utils.torch_functional import (
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
)
if warmup_style == "constant":
critic_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps
)
elif warmup_style == "cosine":
critic_lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=critic_optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps,
)
else:
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
return critic_module, critic_optimizer, critic_lr_scheduler
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get("external_lib", None))
from verl.workers.critic import DataParallelPPOCritic
(
self.critic_module,
self.critic_optimizer,
self.critic_lr_scheduler,
) = self._build_critic_model_optimizer(self.config)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
log_gpu_memory_usage("After offload critic model during init", logger=logger)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.critic_optimizer)
log_gpu_memory_usage("After offload critic optimizer during init", logger=logger)
self.critic = DataParallelPPOCritic(
config=self.config,
critic_module=self.critic_module,
critic_optimizer=self.critic_optimizer,
)
self.flops_counter = FlopsCounter(self.critic_model_config)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.critic_module,
optimizer=self.critic_optimizer,
lr_scheduler=self.critic_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_contents=self.config.checkpoint.contents,
)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_values(self, data: DataProto):
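"""Run a critic forward pass to compute value estimates for the given batch."""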
# Support all hardware backends
data = data.to(get_torch_device().current_device())
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
micro_batch_size = self.config.forward_micro_batch_size_per_gpu
data.meta_info["micro_batch_size"] = micro_batch_size
data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz
# perform forward computation
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
values = self.critic.compute_values(data=data)
output = DataProto.from_dict(tensors={"values": values})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_critic(self, data: DataProto):
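"""Run a critic update on the given batch and return training and perf metrics (MFU, learning rate)."""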
# Support all hardware backends
data = data.to(get_torch_device().current_device())
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
if self._is_offload_optimizer:
load_fsdp_optimizer(
optimizer=self.critic_optimizer, device_id=get_torch_device().current_device()
)
# perform forward computation
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
with Timer(name="update_critic", logger=None) as timer:
metrics = self.critic.update_critic(data=data)
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(
global_num_tokens, delta_time
)
metrics["perf/mfu/critic"] = (
estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
)
self.critic_lr_scheduler.step()
lr = self.critic_lr_scheduler.get_last_lr()[0]
metrics["critic/lr"] = lr
output = DataProto(batch=None, meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.critic_optimizer)
output = output.to("cpu")
return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
self.checkpoint_manager.save_checkpoint(
local_path=local_path,
hdfs_path=hdfs_path,
global_step=global_step,
max_ckpt_to_keep=max_ckpt_to_keep,
)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
self.checkpoint_manager.load_checkpoint(
local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(self.critic_optimizer)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def clear_optimizer_state(self):
print("Clear critic optimizer state")
if self._is_offload_optimizer:
load_fsdp_optimizer(
optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()
)
self.critic_optimizer.state.clear()
self.critic_optimizer.zero_grad()
if self._is_offload_optimizer:
offload_fsdp_optimizer(self.critic_optimizer)