# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm.
Modified from https://github.com/volcengine/verl/blob/v0.5.0/verl/workers/fsdp_workers.py
"""
import json
import logging
import os
import warnings
from contextlib import contextmanager
from dataclasses import asdict
from datetime import timedelta
import psutil
import torch
import torch.distributed
import torch.distributed as dist
import vllm # noqa: F401 ; import vllm to set NCCL_CUMEM_ENABLE automatically.
from codetiming import Timer
from omegaconf import DictConfig, OmegaConf, open_dict
from peft import LoraConfig, TaskType, get_peft_model
from safetensors.torch import save_file
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FlatParameter
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import FSDP_PREFIX
from verl import DataProto
from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.activation_offload import enable_activation_offloading
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.device import (
get_device_id,
get_device_name,
get_nccl_backend,
get_torch_device,
)
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
apply_fsdp2,
fsdp2_load_full_state_dict,
fsdp_version,
get_fsdp_wrap_policy,
get_init_weight_context_manager,
init_fn,
layered_summon_lora_params,
load_fsdp_model_to_gpu,
load_fsdp_optimizer,
offload_fsdp_model_to_cpu,
offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.py_functional import convert_to_regular_types
from verl.workers.fsdp_workers import (
create_device_mesh,
device_name,
get_sharding_strategy,
)
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
from trinity.common.config import AlgorithmConfig
from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
from trinity.manager.synchronizer import Synchronizer
from trinity.trainer.verl.fsdp_checkpoint_manager import FSDPCheckpointManager
from trinity.utils.distributed import init_process_group
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class ActorRolloutRefWorker(Worker):
"""
    This worker can be instantiated as a standalone actor, a standalone rollout worker,
    a standalone reference policy, or a hybrid engine, depending on the configured role.
"""
def __init__(self, config: DictConfig, role: str):
super().__init__()
self.config = config
import torch.distributed
if not torch.distributed.is_initialized():
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
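            # Composite backend: gloo handles CPU tensors while the accelerator
            # backend (e.g. NCCL) handles device tensors.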
torch.distributed.init_process_group(
backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}",
rank=rank,
world_size=world_size,
init_method=os.environ.get("DIST_INIT_METHOD", None),
timeout=timedelta(seconds=self.config.synchronizer.sync_timeout),
)
# build device mesh for FSDP
world_size = torch.distributed.get_world_size()
# TODO(sgm): support FSDP hybrid shard for larger model
self.device_mesh = create_device_mesh(
world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size
)
# build device mesh for Ulysses Sequence Parallel
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.actor.get(
"ulysses_sequence_parallel_size", 1
)
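        # Ulysses sequence parallelism uses a 2D device mesh: ranks are split into
        # a data-parallel ("dp") dimension and a sequence-parallel ("sp") dimension.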
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
device_name,
mesh_shape=(dp, self.ulysses_sequence_parallel_size),
mesh_dim_names=["dp", "sp"],
)
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
self._lora_rank = self.config.model.get("lora_rank", 0)
self._is_lora = self._lora_rank > 0
self.role = role
assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]
self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
self._is_ref = self.role in ["ref", "actor_rollout_ref"]
self._is_offload_param = False
self._is_offload_optimizer = False
if self._is_actor:
self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False)
self._is_offload_optimizer = self.config.actor.fsdp_config.get(
"optimizer_offload", False
)
elif self._is_ref:
            # TODO: manual offload seems to be slower than FSDP's built-in offload
self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False)
# normalize config
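        # Batch sizes in the config are global values; divide them by the
        # data-parallel size (world_size // ulysses_sequence_parallel_size)
        # to obtain per-rank values.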
if self._is_actor:
# note: no need to conduct `ppo_mini_batch_size *= rollout_n` anymore
self.config.actor.ppo_mini_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
assert (
self.config.actor.ppo_mini_batch_size > 0
), f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization"
# micro bsz
if self.config.actor.ppo_micro_batch_size is not None:
self.config.actor.ppo_micro_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
self.config.actor.ppo_micro_batch_size_per_gpu = (
self.config.actor.ppo_micro_batch_size
)
if self.config.actor.ppo_micro_batch_size_per_gpu is not None:
assert (
self.config.actor.ppo_mini_batch_size
% self.config.actor.ppo_micro_batch_size_per_gpu
== 0
), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
assert (
self.config.actor.ppo_mini_batch_size
// self.config.actor.ppo_micro_batch_size_per_gpu
> 0
), f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
# normalize ref config
if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
self.config.ref.log_prob_micro_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
self.config.ref.log_prob_micro_batch_size_per_gpu = (
self.config.ref.log_prob_micro_batch_size
)
@contextmanager
def _fsdp_offload_context(self):
"""A context manager to handle FSDP model GPU loading and CPU offloading."""
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
try:
yield
finally:
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
torch.distributed.barrier()
torch.cuda.empty_cache()
def _build_model_optimizer( # noqa: C901
self,
model_path,
fsdp_config,
optim_config,
override_model_config,
use_remove_padding=False,
use_fused_kernels=False,
enable_gradient_checkpointing=False,
trust_remote_code=False,
use_liger=False,
role="actor",
enable_activation_offload=False,
):
from torch import optim
from torch.distributed.fsdp import CPUOffload, MixedPrecision
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForVision2Seq,
)
from verl.utils.model import (
get_generation_config,
print_model_size,
update_model_config,
)
from verl.utils.torch_dtypes import PrecisionType
assert role in ["actor", "ref"]
log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger)
local_path = model_path
        # Note that we have to create the model in fp32; otherwise the optimizer states are created in bf16, which is incorrect
# TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)
if self.config.model.get("custom_chat_template", None) is not None:
if self.processor is not None:
self.processor.chat_template = self.config.model.custom_chat_template
else:
self.tokenizer.chat_template = self.config.model.custom_chat_template
torch_dtype = fsdp_config.get("model_dtype", None)
if torch_dtype is None:
torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
else:
torch_dtype = PrecisionType.to_dtype(torch_dtype)
# override model kwargs
actor_model_config = AutoConfig.from_pretrained(
local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2"
)
# patch for kimi-vl
if getattr(actor_model_config, "model_type", None) == "kimi_vl":
actor_model_config.text_config.topk_method = "greedy"
self.generation_config = get_generation_config(
local_path, trust_remote_code=trust_remote_code
)
override_config_kwargs = {
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
if self.rank == 0:
print(f"Model config after override: {actor_model_config}")
        # NOTE(fix me): tie_word_embeddings causes meta-tensor init to hang
init_context = get_init_weight_context_manager(
use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh
)
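        # With meta-tensor init, only rank 0 materializes the real weights; the other
        # ranks receive them later via FSDP's sync_module_states (FSDP1) or
        # fsdp2_load_full_state_dict (FSDP2).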
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys():
actor_module_class = AutoModelForVision2Seq
else:
actor_module_class = AutoModelForCausalLM
actor_module = actor_module_class.from_pretrained(
pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=actor_model_config,
trust_remote_code=trust_remote_code,
)
# Apply Liger kernel to the model if use_liger is set to True
if use_liger:
from liger_kernel.transformers.monkey_patch import (
_apply_liger_kernel_to_instance,
)
_apply_liger_kernel_to_instance(model=actor_module)
fused_kernel_options = self.config.model.get("fused_kernel_options", None)
fused_kernels_backend = (
fused_kernel_options.get("impl_backend", None)
if fused_kernel_options is not None
else None
)
apply_monkey_patch(
model=actor_module,
use_remove_padding=use_remove_padding,
ulysses_sp_size=self.ulysses_sequence_parallel_size,
use_fused_kernels=use_fused_kernels,
fused_kernels_backend=fused_kernels_backend,
)
            # some parameters may not be in torch_dtype. TODO(zhangchi.usc1992): remove this after we switch to fsdp2
actor_module.to(torch_dtype)
if enable_gradient_checkpointing:
actor_module.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
if self._is_lora:
print("Applying LoRA to actor module")
actor_module.enable_input_require_grads()
# Convert config to regular Python types before creating PEFT model
lora_config = {
"task_type": TaskType.CAUSAL_LM,
"r": self.config.model.lora_rank,
"lora_alpha": self.config.model.lora_alpha,
"target_modules": convert_to_regular_types(self.config.model.target_modules),
"bias": "none",
}
actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
torch.distributed.barrier()
if self.rank == 0:
print_model_size(actor_module)
log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger)
# We wrap FSDP for rollout as well
mixed_precision_config = fsdp_config.get("mixed_precision", None)
if mixed_precision_config is not None:
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
reduce_dtype = PrecisionType.to_dtype(
mixed_precision_config.get("reduce_dtype", "fp32")
)
buffer_dtype = PrecisionType.to_dtype(
mixed_precision_config.get("buffer_dtype", "fp32")
)
else:
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
buffer_dtype = torch.float32
mixed_precision = MixedPrecision(
param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype
)
auto_wrap_policy = get_fsdp_wrap_policy(
module=actor_module,
config=fsdp_config.get("wrap_policy", None),
is_lora=self.config.model.get("lora_rank", 0) > 0,
)
if self.rank == 0:
print(f"wrap_policy: {auto_wrap_policy}")
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
# TODO: add transformer policy
# We force reference policy to use CPUOffload to save memory.
        # We force CPUOffload off for the actor because it causes incorrect results when using gradient accumulation
cpu_offload = None if role == "actor" else CPUOffload(offload_params=True)
fsdp_strategy = self.config.actor.strategy
if fsdp_strategy == "fsdp":
actor_module_fsdp = FSDP(
actor_module,
cpu_offload=cpu_offload,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=get_device_id(),
sharding_strategy=sharding_strategy, # zero3
mixed_precision=mixed_precision,
sync_module_states=True,
device_mesh=self.device_mesh,
forward_prefetch=self.config.actor.fsdp_config.forward_prefetch,
)
elif fsdp_strategy == "fsdp2":
assert (
CPUOffloadPolicy is not None
), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
)
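            # With FSDP2 the framework can manage CPU offload itself; when offload_policy
            # is enabled for the actor, disable the manual load/offload helpers so the
            # parameters and optimizer state are not moved twice.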
if role == "actor" and fsdp_config.offload_policy:
cpu_offload = CPUOffloadPolicy(pin_memory=True)
self._is_offload_param = False
self._is_offload_optimizer = False
else:
cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True)
fsdp_kwargs = {
"mesh": fsdp_mesh,
"mp_policy": mp_policy,
"offload_policy": cpu_offload,
"reshard_after_forward": fsdp_config.reshard_after_forward,
}
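            # FSDP2 shards the module in place: snapshot the full state dict first,
            # apply fully_shard, then load the full weights back into the sharded modules.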
full_state = actor_module.state_dict()
apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config)
fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)
actor_module_fsdp = actor_module
else:
raise NotImplementedError(f"not implement {fsdp_strategy}")
if enable_activation_offload:
enable_activation_offloading(
actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing
)
log_gpu_memory_usage(f"After {role} FSDP init", logger=logger)
# TODO: add more optimizer args into config
if role == "actor" and optim_config is not None:
from verl.utils.torch_functional import (
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
)
actor_optimizer = optim.AdamW(
actor_module_fsdp.parameters(),
lr=optim_config.lr,
betas=optim_config.get("betas", (0.9, 0.999)),
weight_decay=optim_config.get("weight_decay", 1e-2),
)
total_steps = optim_config.get("total_training_steps", 0)
num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
warmup_style = optim_config.get("warmup_style", "constant")
min_lr_ratio = optim_config.get("min_lr_ratio", 0.0)
num_cycles = optim_config.get("num_cycles", 0.5)
if num_warmup_steps < 0:
num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
if self.rank == 0:
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
if warmup_style == "constant":
actor_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps
)
elif warmup_style == "cosine":
actor_lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=actor_optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps,
min_lr_ratio=min_lr_ratio,
num_cycles=num_cycles,
)
else:
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
else:
actor_optimizer = None
actor_lr_scheduler = None
return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
from trinity.trainer.verl.dp_actor import DataParallelPPOActor
        # This is used to import external_lib into the Hugging Face ecosystem
import_external_libs(self.config.model.get("external_lib", None))
override_model_config = OmegaConf.to_container(
self.config.model.get("override_config", OmegaConf.create())
)
use_remove_padding = self.config.model.get("use_remove_padding", False)
use_shm = self.config.model.get("use_shm", False)
use_fused_kernels = self.config.model.get("use_fused_kernels", False)
if self._is_actor:
# we need the model for actor and rollout
optim_config = self.config.actor.optim
fsdp_config = self.config.actor.fsdp_config
local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
(
self.actor_module_fsdp,
self.actor_optimizer,
self.actor_lr_scheduler,
self.actor_model_config,
) = self._build_model_optimizer(
model_path=local_path,
fsdp_config=fsdp_config,
optim_config=optim_config,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
use_fused_kernels=use_fused_kernels,
enable_gradient_checkpointing=self.config.model.get(
"enable_gradient_checkpointing", False
),
trust_remote_code=self.config.model.get("trust_remote_code", False),
use_liger=self.config.model.get("use_liger", False),
role="actor",
enable_activation_offload=self.config.model.get("enable_activation_offload", False),
)
# get the original unwrapped module
if fsdp_version(self.actor_module_fsdp) == 1:
self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
log_gpu_memory_usage("After offload actor model during init", logger=logger)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)
if self._is_actor:
OmegaConf.set_struct(self.config.actor, True)
with open_dict(self.config.actor):
self.config.actor.use_remove_padding = use_remove_padding
self.config.actor.use_fused_kernels = use_fused_kernels
self.actor = DataParallelPPOActor(
config=self.config.actor,
actor_module=self.actor_module_fsdp,
actor_optimizer=self.actor_optimizer,
)
if self._is_ref:
local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
self.ref_module_fsdp = self._build_model_optimizer(
model_path=local_path,
fsdp_config=self.config.ref.fsdp_config,
optim_config=None,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
use_fused_kernels=use_fused_kernels,
trust_remote_code=self.config.model.get("trust_remote_code", False),
use_liger=self.config.model.get("use_liger", False),
role="ref",
)[0]
OmegaConf.set_struct(self.config.ref, True)
with open_dict(self.config.ref):
self.config.ref.use_remove_padding = use_remove_padding
self.config.ref.use_fused_kernels = use_fused_kernels
self.ref_policy = DataParallelPPOActor(
config=self.config.ref, actor_module=self.ref_module_fsdp
)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.ref_module_fsdp,
optimizer=None,
lr_scheduler=None,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_config=self.config.ref.checkpoint,
)
if self._is_actor:
self.flops_counter = FlopsCounter(self.actor_model_config)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.actor_module_fsdp,
optimizer=self.actor.actor_optimizer,
lr_scheduler=self.actor_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_config=self.config.actor.checkpoint,
config=self.config.synchronizer,
)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def setup_weight_sync_group(self):
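        """Set up the process group used to broadcast trainer weights to the explorer.

        Only runs when the synchronizer uses NCCL-based weight sync.
        """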
if self.config.synchronizer.sync_method == SyncMethod.NCCL:
model = self.actor_module_fsdp
self.named_modules = []
self.state_dict_meta = []
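            # Record (name, dtype, shape) for every parameter; this metadata is sent to
            # the Synchronizer so the explorer side knows what to expect during the
            # weight broadcast.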
with self._fsdp_offload_context():
if self.config.actor.strategy == "fsdp":
for name, module in model.named_modules():
if isinstance(module, FSDP):
self.named_modules.append((name, module))
for name_prefix, module in self.named_modules:
with FSDP.summon_full_params(module, recurse=False):
for name, param in module.named_parameters():
if isinstance(param, FlatParameter):
continue
realname = (
name_prefix[len(FSDP_PREFIX) :] + "." + name
if name_prefix
else name
)
self.state_dict_meta.append(
(realname, str(param.dtype), tuple(param.shape))
)
param = None
torch.cuda.empty_cache()
else: # fsdp2
for name, param in model.named_parameters():
self.state_dict_meta.append((name, str(param.dtype), tuple(param.shape)))
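            # Only global rank 0 of the trainer joins the trainer-explorer group
            # (world_size = explorer_world_size + 1) and acts as the broadcast source.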
if torch.distributed.get_rank() == 0:
import ray
master_address, master_port = self.get_availale_master_addr_port()
world_size = self.config.synchronizer.explorer_world_size + 1
print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).")
synchronizer = Synchronizer.get_actor(
namespace=self.config.synchronizer.ray_namespace
)
setup_ref = synchronizer.setup_weight_sync_group.remote(
master_address, master_port, self.state_dict_meta
)
timeout = self.config.synchronizer.sync_timeout
self._model_update_group = init_process_group(
host=master_address,
port=master_port,
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
backend="nccl",
timeout=timeout,
world_size=world_size,
device_id=torch.device(f"cuda:{get_device_id()}"),
rank=0,
)
torch.distributed.barrier(group=self._model_update_group)
ray.get(setup_ref)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def sync_weight(self):
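        """Broadcast the latest actor weights to the explorer via the weight sync group."""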
with self._fsdp_offload_context():
if self.config.actor.strategy == "fsdp":
for name_prefix, module in self.named_modules:
with FSDP.summon_full_params(module, recurse=False):
if torch.distributed.get_rank() == 0:
for name, param in module.named_parameters():
if isinstance(param, FlatParameter):
continue
torch.distributed.broadcast(
param, 0, group=self._model_update_group
)
param = None
else: # fsdp2
for name, param in self.actor_module_fsdp.named_parameters():
full_param = param.full_tensor().detach().to(device=get_device_id())
if torch.distributed.get_rank() == 0:
torch.distributed.broadcast(full_param, 0, group=self._model_update_group)
del full_param
if torch.distributed.get_rank() == 0:
torch.distributed.barrier(group=self._model_update_group)
torch.cuda.synchronize()
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def upload_state_dict(self, trainer_step: int):
with self._fsdp_offload_context():
self.checkpoint_manager.upload_state_dict(trainer_step)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def set_algorithm(self, algo_config: AlgorithmConfig):
self.actor.set_algorithm(algo_config)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
        # Support all hardware backends
        data = data.to("cpu")  # data is moved to the device per micro-batch inside actor.update_policy
assert self._is_actor
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
if self._is_offload_optimizer:
load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
# perform training
with Timer(name="update_policy", logger=None) as timer:
metrics = self.actor.update_policy(data=data)
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
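            # MFU: achieved FLOPs for this step (scaled by ppo_epochs) relative to the
            # hardware's promised peak, averaged across ranks.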
estimated_flops, promised_flops = self.flops_counter.estimate_flops(
global_num_tokens, delta_time
)
metrics["perf/mfu/actor"] = (
estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
)
metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (
1024**3
)
metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (
1024**3
)
metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)
lr = self.actor_lr_scheduler.get_last_lr()[0]
metrics["actor/lr"] = lr
self.actor_lr_scheduler.step()
# TODO: here, we should return all metrics
output = DataProto(meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
log_gpu_memory_usage("After offload actor model during update_actor", logger=logger)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_prob(self, data: DataProto):
        # When is_lora is True, we use the actor with the LoRA adapter disabled to compute the log_prob,
        # which is mostly used for the ref log_prob calculation
assert self._is_actor
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
        # Support all hardware backends
from contextlib import nullcontext
is_lora = data.meta_info.pop("is_lora", False)
adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext()
data = data.to(get_device_id())
        # we should always recompute old_log_probs when using the hybrid engine
data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu
data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz
data.meta_info["temperature"] = self.config.rollout.temperature
        # recompute the log_prob
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
with adapter_ctx:
output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
output = DataProto.from_dict(
tensors={"old_log_probs": output, "entropys": entropys},
meta_info={"temperature": self.config.rollout.temperature},
)
output = self.ulysses_sharding_manager.postprocess_data(output)
output = output.to("cpu")
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1:
self.actor.actor_module._handle.reshard(True)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
if self._is_lora:
# if _is_lora, actor without lora applied is the ref
data.meta_info["is_lora"] = True
data = self.compute_log_prob(data)
            # these old_log_probs are in fact the ref_log_prob
data = DataProto.from_dict(tensors={"ref_log_prob": data.batch["old_log_probs"]})
return data
assert self._is_ref
        # Otherwise, this worker has a standalone ref model.
        # Support all hardware backends
data = data.to(get_device_id())
micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
data.meta_info["micro_batch_size"] = micro_batch_size
data.meta_info["temperature"] = self.config.rollout.temperature
data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
output = DataProto.from_dict(tensors={"ref_log_prob": output})
output = self.ulysses_sharding_manager.postprocess_data(output)
output = output.to("cpu")
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1 and fsdp_version(self.ref_policy.actor_module) == 1:
self.ref_policy.actor_module._handle.reshard(True)
return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(
self,
local_path,
hdfs_path=None,
global_step=0,
max_ckpt_to_keep=None,
model_state_dict_only=False,
):
from verl.utils.logger import log_with_rank
# only support save and load ckpt for actor
assert self._is_actor
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
self.checkpoint_manager.save_checkpoint(
local_path=local_path,
hdfs_path=hdfs_path,
global_step=global_step,
max_ckpt_to_keep=max_ckpt_to_keep,
model_state_dict_only=model_state_dict_only,
)
dist.barrier()
if self._is_lora and hasattr(
getattr(self, "actor_module", self.actor_module_fsdp), "peft_config"
):
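            # Additionally save the LoRA adapter in PEFT format: gather the adapter
            # weights layer by layer, then rank 0 writes adapter_model.safetensors and
            # adapter_config.json under <local_path>/lora_adapter.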
lora_save_path = os.path.join(local_path, "lora_adapter")
peft_model = getattr(self, "actor_module", self.actor_module_fsdp)
peft_config = {}
if dist.get_rank() == 0:
os.makedirs(lora_save_path, exist_ok=True)
peft_config = asdict(peft_model.peft_config.get("default", {}))
peft_config["task_type"] = peft_config["task_type"].value
peft_config["peft_type"] = peft_config["peft_type"].value
peft_config["target_modules"] = list(peft_config["target_modules"])
try:
if fsdp_version(self.actor_module_fsdp) > 0:
self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name())
lora_params = layered_summon_lora_params(self.actor_module_fsdp)
if dist.get_rank() == 0:
save_file(
lora_params, os.path.join(lora_save_path, "adapter_model.safetensors")
)
with open(
os.path.join(lora_save_path, "adapter_config.json"),
"w",
encoding="utf-8",
) as f:
json.dump(peft_config, f, ensure_ascii=False, indent=4)
except Exception as e:
log_with_rank(
f"Save LoRA Adapter Error ({e})",
rank=dist.get_rank(),
logger=logger,
log_only_rank_0=True,
)
dist.barrier()
log_with_rank(
f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}",
rank=dist.get_rank(),
logger=logger,
log_only_rank_0=True,
)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
if self._is_actor and self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
self.checkpoint_manager.load_checkpoint(
local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
)
if self._is_actor and self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
if self._is_actor and self._is_offload_optimizer:
offload_fsdp_optimizer(self.actor_optimizer)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def clear_optimizer_state(self):
print("Clear actor optimizer state")
if self._is_offload_optimizer:
load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())
self.actor_optimizer.state.clear()
self.actor_optimizer.zero_grad()
if self._is_offload_optimizer:
offload_fsdp_optimizer(self.actor_optimizer)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def wait_on_save_thread(self) -> None:
self.checkpoint_manager.wait_on_save_thread()
class CriticWorker(Worker):
def __init__(self, config):
super().__init__()
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
backend=get_nccl_backend(),
init_method=os.environ.get("DIST_INIT_METHOD", None),
timeout=timedelta(seconds=self.config.synchronizer.sync_timeout),
)
self.config = config
# build device mesh for Ulysses Sequence Parallel
world_size = torch.distributed.get_world_size()
from torch.distributed.device_mesh import init_device_mesh
fsdp_size = self.config.model.fsdp_config.fsdp_size
self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
device_name,
mesh_shape=(dp, self.ulysses_sequence_parallel_size),
mesh_dim_names=["dp", "sp"],
)
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
# set FSDP offload params
self._is_offload_param = self.config.model.fsdp_config.param_offload
self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload
# normalize config
# note: no need to conduct `ppo_mini_batch_size *= rollout_n` anymore
self.config.ppo_mini_batch_size //= (
torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
)
if self.config.ppo_micro_batch_size is not None:
self.config.ppo_micro_batch_size //= (
torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
)
self.config.forward_micro_batch_size //= (
torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
)
self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size
if self.config.ppo_micro_batch_size_per_gpu is not None:
assert (
self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
assert (
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0
), f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
self._is_lora = self.config.model.get("lora_rank", 0) > 0
def _build_critic_model_optimizer(self, config): # noqa: C901
# the following line is necessary
from torch import optim
from torch.distributed.fsdp import MixedPrecision
from verl.utils.model import load_valuehead_model, print_model_size
from verl.utils.torch_dtypes import PrecisionType
use_shm = config.model.get("use_shm", False)
local_path = copy_to_local(config.model.path, use_shm=use_shm)
        # Note that the actor and critic tokenizers may differ, so override the tokenizer info with the actor's.
        # The critic may be a randomly initialized model from any architecture and need not match the actor.
tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm)
self.tokenizer = hf_tokenizer(
tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)
)
self.processor = hf_processor(
tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)
)
if self.config.model.get("custom_chat_template", None) is not None:
if self.processor is not None:
self.processor.chat_template = self.config.model.custom_chat_template
else:
self.tokenizer.chat_template = self.config.model.custom_chat_template
from omegaconf import OmegaConf
override_config = OmegaConf.to_container(
self.config.model.get("override_config", OmegaConf.create())
)
override_config_kwargs = {
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_config)
if self.rank == 0:
print(f"Critic overriding config {override_config_kwargs}")
torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
torch_dtype = PrecisionType.to_dtype(torch_dtype)
from transformers import AutoConfig
critic_model_config = AutoConfig.from_pretrained(
local_path,
attn_implementation="flash_attention_2",
trust_remote_code=config.model.get("trust_remote_code", False),
)
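        # The critic is a value model: a single scalar head (num_labels = 1) on top of
        # the base language model.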
critic_model_config.num_labels = 1
# patch for kimi-vl
if getattr(critic_model_config, "model_type", None) == "kimi_vl":
critic_model_config.text_config.topk_method = "greedy"
init_context = get_init_weight_context_manager(
use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh
)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
critic_model_config.classifier_dropout = 0.0
critic_model_config.hidden_dropout = "0"
critic_model_config.summary_dropout_prob = 0.0
critic_module = load_valuehead_model(
local_path,
torch_dtype,
critic_model_config,
config.model.get("trust_remote_code", False),
)
use_remove_padding = config.model.get("use_remove_padding", False)
apply_monkey_patch(
model=critic_module,
use_remove_padding=use_remove_padding,
ulysses_sp_size=self.ulysses_sequence_parallel_size,
)
            # some parameters may not be in torch_dtype
critic_module.to(torch_dtype)
if config.model.get("enable_gradient_checkpointing", False):
critic_module.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
if self._is_lora:
print("Applying LoRA to critic module")
critic_module.enable_input_require_grads()
# Convert config to regular Python types before creating PEFT model
lora_config = {
"task_type": TaskType.CAUSAL_LM,
"r": self.config.model.lora_rank,
"lora_alpha": self.config.model.lora_alpha,
"target_modules": convert_to_regular_types(self.config.model.target_modules),
"bias": "none",
}
critic_module = get_peft_model(critic_module, LoraConfig(**lora_config))
if self.rank == 0:
print_model_size(critic_module)
self.critic_model_config = critic_model_config
fsdp_config = self.config.model.fsdp_config
mixed_precision_config = fsdp_config.get("mixed_precision", None)
if mixed_precision_config is not None:
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
reduce_dtype = PrecisionType.to_dtype(
mixed_precision_config.get("reduce_dtype", "fp32")
)
buffer_dtype = PrecisionType.to_dtype(
mixed_precision_config.get("buffer_dtype", "fp32")
)
else:
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
buffer_dtype = torch.float32
mixed_precision = MixedPrecision(
param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype
)
auto_wrap_policy = get_fsdp_wrap_policy(
module=critic_module,
config=self.config.model.fsdp_config.wrap_policy,
is_lora=self.config.model.get("lora_rank", 0) > 0,
)
log_gpu_memory_usage("Before critic FSDP", logger=None)
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
        # Note: We force CPUOffload off for the critic because it causes incorrect results when using gradient accumulation
if config.strategy == "fsdp":
critic_module = FSDP(
critic_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=get_device_id(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
sync_module_states=True,
forward_prefetch=self.config.model.fsdp_config.forward_prefetch,
device_mesh=self.device_mesh,
cpu_offload=None,
)
elif config.strategy == "fsdp2":
assert (
CPUOffloadPolicy is not None
), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
)
offload_policy = None
if fsdp_config.offload_policy:
self._is_offload_param = False
self._is_offload_optimizer = False
offload_policy = CPUOffloadPolicy(pin_memory=True)
fsdp_kwargs = {
"mesh": fsdp_mesh,
"mp_policy": mp_policy,
"offload_policy": offload_policy,
"reshard_after_forward": fsdp_config.reshard_after_forward,
}
full_state = critic_module.state_dict()
apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config)
fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy)
else:
raise NotImplementedError(f"Unknown strategy {config.strategy}")
if config.model.get("enable_activation_offload", False):
enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False)
enable_activation_offloading(
critic_module, config.strategy, enable_gradient_checkpointing
)
log_gpu_memory_usage("After critic FSDP", logger=None)
critic_optimizer = optim.AdamW(
critic_module.parameters(),
lr=config.optim.lr,
betas=config.optim.get("betas", (0.9, 0.999)),
weight_decay=config.optim.get("weight_decay", 1e-2),
)
total_steps = config.optim.get("total_training_steps", 0)
num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
warmup_style = config.optim.get("warmup_style", "constant")
if num_warmup_steps < 0:
num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
if self.rank == 0:
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
from verl.utils.torch_functional import (
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
)
if warmup_style == "constant":
critic_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps
)
elif warmup_style == "cosine":
critic_lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=critic_optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps,
)
else:
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
return critic_module, critic_optimizer, critic_lr_scheduler
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
        # This is used to import external_lib into the Hugging Face ecosystem
import_external_libs(self.config.model.get("external_lib", None))
from verl.workers.critic import DataParallelPPOCritic
(
self.critic_module,
self.critic_optimizer,
self.critic_lr_scheduler,
) = self._build_critic_model_optimizer(self.config)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
log_gpu_memory_usage("After offload critic model during init", logger=logger)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.critic_optimizer)
log_gpu_memory_usage("After offload critic optimizer during init", logger=logger)
self.critic = DataParallelPPOCritic(
config=self.config,
critic_module=self.critic_module,
critic_optimizer=self.critic_optimizer,
)
self.flops_counter = FlopsCounter(self.critic_model_config)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.critic_module,
optimizer=self.critic_optimizer,
lr_scheduler=self.critic_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_config=self.config.checkpoint,
)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_values(self, data: DataProto):
        # Support all hardware backends
data = data.to(get_device_id())
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
micro_batch_size = self.config.forward_micro_batch_size_per_gpu
data.meta_info["micro_batch_size"] = micro_batch_size
data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz
# perform forward computation
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
values = self.critic.compute_values(data=data)
output = DataProto.from_dict(tensors={"values": values})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_critic(self, data: DataProto):
        # Support all hardware backends
data = data.to(get_device_id())
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
if self._is_offload_optimizer:
load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id())
        # perform the critic update
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
with Timer(name="update_critic", logger=None) as timer:
metrics = self.critic.update_critic(data=data)
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(
global_num_tokens, delta_time
)
metrics["perf/mfu/critic"] = (
estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
)
self.critic_lr_scheduler.step()
lr = self.critic_lr_scheduler.get_last_lr()[0]
metrics["critic/lr"] = lr
output = DataProto(batch=None, meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.critic_optimizer)
output = output.to("cpu")
return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
self.checkpoint_manager.save_checkpoint(
local_path=local_path,
hdfs_path=hdfs_path,
global_step=global_step,
max_ckpt_to_keep=max_ckpt_to_keep,
)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.critic_module)
self.checkpoint_manager.load_checkpoint(
local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.critic_module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(self.critic_optimizer)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def clear_optimizer_state(self):
print("Clear critic optimizer state")
if self._is_offload_optimizer:
load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id())
self.critic_optimizer.state.clear()
self.critic_optimizer.zero_grad()
if self._is_offload_optimizer:
offload_fsdp_optimizer(self.critic_optimizer)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def wait_on_save_thread(self) -> None:
self.checkpoint_manager.wait_on_save_thread()