# -*- coding: utf-8 -*-
"""veRL Trainer Class
Modified from verl/trainer/ppo/ray_trainer.py
"""
import os
import sys
from pprint import pprint
from typing import Dict, List
import pandas as pd
import ray
import torch
from omegaconf import OmegaConf
from verl.trainer.ppo.metric_utils import (
compute_data_metrics,
compute_throughout_metrics,
compute_timing_metrics,
reduce_metrics,
)
from verl.trainer.ppo.ray_trainer import (
RayClassWithInitArgs,
RayPPOTrainer,
RayWorkerGroup,
ResourcePoolManager,
Role,
_timer,
create_colocated_worker_cls,
find_latest_ckpt_path,
)
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_local_path_from_hdfs
from trinity.algorithm import ADVANTAGE_FN, KL_FN, SAMPLE_STRATEGY
from trinity.algorithm.algorithm import ALGORITHM_TYPE, SFTAlgorithm
from trinity.algorithm.algorithm_manager import AlgorithmManager
from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import Config
from trinity.common.experience import Experiences
from trinity.trainer.trainer import TrainEngineWrapper
from trinity.utils.monitor import MONITOR
class _InternalDataLoader:
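    """A placeholder dataloader that satisfies verl's ``RayPPOTrainer`` interface.

    Training data is supplied by the Trinity buffer through the sample strategy,
    so this loader holds no dataset and yields nothing.
    """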
def __init__(self, config):
self.config = config
self.dataset = None
self.index = 0
self.experience_buffer = None
def state_dict(self):
return None
def load_state_dict(self, *args, **kwargs):
pass
def __getstate__(self):
state = self.__dict__.copy()
return state
def __setstate__(self, state):
self.__dict__.update(state)
def __iter__(self):
self.index = 0
return self
def __next__(self):
raise StopIteration
class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper):
"""A wrapper for verl.trainer.ppo.RayPPOTrainer."""
def __init__(
self,
global_config: Config,
):
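        """Build the wrapped verl PPO trainer from a Trinity ``Config``.

        Args:
            global_config: The full Trinity configuration; its ``trainer``,
                ``algorithm``, ``buffer`` and ``monitor`` sections are used here.
        """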
train_config = global_config.trainer
config = OmegaConf.structured(train_config.trainer_config)
# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
# instantiate tokenizer
tokenizer = hf_tokenizer(local_path)
# define worker classes
if config.actor_rollout_ref.actor.strategy == "fsdp":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from trinity.trainer.verl.fsdp_workers import (
ActorRolloutRefWorker,
CriticWorker,
)
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == "megatron":
raise NotImplementedError("Not support megatron for now.")
else:
raise NotImplementedError
role_worker_mapping = {
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.Critic: ray.remote(CriticWorker),
Role.RefPolicy: ray.remote(ActorRolloutRefWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
Role.RefPolicy: global_pool_id,
}
resource_pool_manager = ResourcePoolManager(
resource_pool_spec=resource_pool_spec, mapping=mapping
)
self.algorithm_config = global_config.algorithm
self.algorithm = None
self.algorithm_manager = AlgorithmManager(global_config)
        # specify the advantage and KL penalty functions used by the various RFT algorithms
algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type)
if algorithm.use_advantage:
self.advantage_fn = ADVANTAGE_FN.get(self.algorithm_config.advantage_fn)(
**self.algorithm_config.advantage_fn_args
)
self.kl_fn = KL_FN.get(self.algorithm_config.kl_penalty_fn)(
**self.algorithm_config.kl_penalty_fn_args
)
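        # the sample strategy pulls training experiences from the buffer at each step (see `train_step`)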
self.sample_strategy = SAMPLE_STRATEGY.get(global_config.algorithm.sample_strategy)(
buffer_config=global_config.buffer,
trainer_type=global_config.trainer.trainer_type,
**global_config.algorithm.sample_strategy_args,
)
super().__init__(
config,
tokenizer,
role_worker_mapping,
resource_pool_manager,
ray_worker_group_cls,
)
self.init_workers()
self.logger = MONITOR.get(global_config.monitor.monitor_type)(
project=config.trainer.project_name,
name=config.trainer.experiment_name,
role=global_config.trainer.name,
config=global_config,
)
self.reset_experiences_example_table()
def _validate_config(self): # TODO
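        # whether a critic is required is determined by the algorithm type rather than by verl's config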
algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type)
self.use_critic = algorithm.use_critic
super()._validate_config()
def init_workers(self):
"""Initialize distributed training workers using Ray backend.
Creates:
1. Ray resource pools from configuration
2. Worker groups for each role (actor, critic, etc.)
"""
self.resource_pool_manager.create_resource_pool()
self.resource_pool_to_cls = {
pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()
}
# create actor and rollout
if self.hybrid_engine:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
actor_rollout_cls = RayClassWithInitArgs(
cls=self.role_worker_mapping[Role.ActorRollout],
config=self.config.actor_rollout_ref,
role="actor",
)
self.resource_pool_to_cls[resource_pool]["actor"] = actor_rollout_cls
else:
raise NotImplementedError
# create critic
if self.use_critic:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
critic_cls = RayClassWithInitArgs(
cls=self.role_worker_mapping[Role.Critic], config=self.config.critic
)
self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
# create reference policy if needed
if self.use_reference_policy:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
ref_policy_cls = RayClassWithInitArgs(
self.role_worker_mapping[Role.RefPolicy],
config=self.config.actor_rollout_ref,
role="ref",
)
self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
# create a reward model if reward_fn is None
if self.use_rm:
# we create a RM here
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
rm_cls = RayClassWithInitArgs(
self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model
)
self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls
# initialize WorkerGroup
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
# you should not use `create_colocated_worker_cls`.
# Instead, directly pass different resource pool to different worker groups.
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
all_wg = {}
wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
wg_kwargs[
"ray_wait_register_center_timeout"
] = self.config.trainer.ray_wait_register_center_timeout
for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = self.ray_worker_group_cls(
resource_pool=resource_pool,
ray_cls_with_init=worker_dict_cls,
device_name=self.device_name,
**wg_kwargs,
)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)
if self.use_critic:
self.critic_wg = all_wg["critic"]
self.critic_wg.init_model()
if self.use_reference_policy and not self.ref_in_actor:
self.ref_policy_wg = all_wg["ref"]
self.ref_policy_wg.init_model()
if self.use_rm:
self.rm_wg = all_wg["rm"]
self.rm_wg.init_model()
# we should create rollout at the end so that vllm can have a better estimation of kv cache memory
self.actor_rollout_wg = all_wg["actor"]
self.actor_rollout_wg.init_model()
def reset_experiences_example_table(self):
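        """Clear the cached experience samples used for preview/table logging."""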
self.sample_exps_to_log = []
@property
def train_step_num(self) -> int:
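        """The number of training steps completed so far."""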
return self.global_steps
def prepare(self):
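        """Prepare for training.

        Sets up the weight-sync group, loads the latest checkpoint, and
        optionally runs validation before training starts.
        """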
self.actor_rollout_wg.setup_weight_sync_group()
# The global step counter, initialized to 0
# It represents the total number of training steps completed so far
# We increment this counter at the beginning of each training step
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
pprint(f"Initial validation metrics: {val_metrics}")
self.logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
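        """Install a placeholder dataloader.

        Training data is provided by the Trinity buffer through the sample
        strategy, so verl's dataset-based dataloader is not used.
        """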
self.train_dataloader = _InternalDataLoader(self.config)
# TODO: compute total training steps
self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize
    def train_step(self) -> bool:  # noqa: C901
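        """Run a single training step.

        Returns:
            bool: True if training should continue, False if the data source is
            exhausted or the total number of training steps has been reached.
        """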
metrics = {}
try:
batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1)
prefix_metrics(sample_metrics, "sample", metrics)
except StopIteration:
print("No more data to train. Stop training.")
return False
self.global_steps += 1
timing_raw = {}
algorithm_config = self.algorithm_manager.get_current_algorithm_config(self.global_steps)
algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type)
if self.algorithm != algorithm:
self.actor_rollout_wg.set_algorithm(algorithm_config)
if self.algorithm == SFTAlgorithm:
self.sft_to_rft()
self.algorithm = algorithm
with _timer("step", timing_raw):
batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
if self.algorithm.can_balance_batch and self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics) # TODO this may affect multi-turn
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(
batch.batch["attention_mask"], dim=-1
).tolist()
if self.algorithm.use_reference: # ref_logprob may not be used
# compute reference log_prob
with _timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
if self.algorithm.use_critic:
with _timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
if self.algorithm.use_advantage:
with _timer("adv", timing_raw):
# compute kl penalty
batch, kl_metrics = self.kl_fn.apply_kl_penalty_to_reward(batch)
metrics.update(prefix_metrics(kl_metrics, prefix="critic"))
# compute advantages, executed on the driver process
batch, _ = self.advantage_fn(batch)
# update critic
if self.algorithm.use_critic:
with _timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# implement critic warmup
if (
not self.algorithm.use_critic
or self.config.trainer.critic_warmup <= self.global_steps
):
# update actor
with _timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
                    # TODO: send the updated weights to the explorer
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
if (
self.config.trainer.save_freq > 0
and self.global_steps % self.config.trainer.save_freq == 0
):
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
if self.algorithm.use_advantage: # TODO
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(
compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)
)
if self.algorithm.use_advantage and self.config.enable_preview: # TODO
self._log_experiences(exp_samples)
        # TODO: make a canonical logger that supports various backends
self.logger.log(data=metrics, step=self.global_steps)
train_status = self.global_steps < self.total_training_steps
if not train_status or self.algorithm_manager.need_save(self.global_steps):
if (
self.config.trainer.save_freq == 0
or self.global_steps % self.config.trainer.save_freq != 0
):
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
return train_status
def _log_single_experience(
self, experiences: Experiences, idx: int, skip_special_tokens: bool
) -> None:
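        """Decode one experience and append it to the preview table."""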
reward = experiences.rewards[idx]
attn_mask = experiences.attention_masks[idx].bool()
prompt_token = experiences.tokens[idx][: experiences.prompt_length][
attn_mask[: experiences.prompt_length]
]
response_token = experiences.tokens[idx][experiences.prompt_length :][
attn_mask[experiences.prompt_length :]
]
prompt_text = self.tokenizer.decode(prompt_token, skip_special_tokens=skip_special_tokens)
response_text = self.tokenizer.decode(
response_token, skip_special_tokens=skip_special_tokens
)
        # append a plain dict so that `_log_experiences` can build the table with
        # `pd.DataFrame(self.sample_exps_to_log)` (the attribute is a list, see
        # `reset_experiences_example_table`)
        self.sample_exps_to_log.append(
            {
                "step": self.global_steps,
                "reward": reward,
                "prompt": prompt_text,
                "response": response_text,
            }
        )
def _log_experiences(self, samples: List[Dict]) -> None:
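        """Collect sampled experiences and periodically log them as a table."""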
self.sample_exps_to_log.extend(samples)
if self.global_steps % self.config.trainer.sync_freq == 0:
self.logger.log_table(
"rollout_examples", pd.DataFrame(self.sample_exps_to_log), self.global_steps
)
self.reset_experiences_example_table()
def save_checkpoint(self) -> None:
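        """Save the current training state via verl's checkpointing."""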
self._save_checkpoint()
def sync_weight(self) -> None:
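        """Trigger synchronization of the actor weights (see ``setup_weight_sync_group`` in ``prepare``)."""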
self.actor_rollout_wg.sync_weight()
def sft_to_rft(self) -> None:
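        """Switch from SFT to RFT training.

        Reloads the latest actor checkpoint into the reference policy workers
        and clears the actor (and critic, if used) optimizer states.
        """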
# load from hdfs
if self.config.trainer.default_hdfs_dir is not None:
raise NotImplementedError("load from hdfs is not implemented yet")
else:
checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path
if not os.path.isabs(checkpoint_folder):
working_dir = os.getcwd()
checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest
        # determine the checkpoint folder to resume from based on resume_mode
if self.config.trainer.resume_mode == "auto":
if global_step_folder is None:
print("Training from scratch")
return
else:
if not (self.config.trainer.resume_from_path and global_step_folder is not None):
assert isinstance(
self.config.trainer.resume_mode, str
), "resume ckpt must be str type"
assert (
"global_step_" in self.config.trainer.resume_mode
), "resume ckpt must specify the global_steps"
global_step_folder = self.config.trainer.resume_mode
if not os.path.isabs(global_step_folder):
working_dir = os.getcwd()
global_step_folder = os.path.join(working_dir, global_step_folder)
print(f"Load from checkpoint folder: {global_step_folder}")
# set global step
global_steps = int(global_step_folder.split("global_step_")[-1])
assert self.global_steps == global_steps + 1
print(f"Resuming from {global_step_folder}")
actor_path = os.path.join(global_step_folder, "actor")
print(f"Loading actor from {actor_path} to ref_policy_wg")
self.ref_policy_wg.load_checkpoint(actor_path, del_local_after_load=False)
self.actor_rollout_wg.clear_optimizer_state()
if self.use_critic:
self.critic_wg.clear_optimizer_state()
print("sft to rft finished")
def shutdown(self) -> None:
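        """Release trainer resources (currently a no-op)."""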
pass