"""Utils for ccompatibility issues with verl."""
import os
from logging import Logger
from typing import List
import numpy as np
import torch
from verl import DataProto
from verl.trainer.ppo.metric_utils import _compute_response_info
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from trinity.common.config import Config
from trinity.common.experience import (
Experience,
gather_action_masks,
gather_attention_masks,
gather_response_attrs,
gather_token_ids,
split_dpo_experience_to_single_turn,
)
def to_data_proto(
experiences: List[Experience], pad_token_id: int, logger: Logger
) -> DataProto: # noqa: C901
"""Convert List[Experience] to verl DataProto."""
assert len(experiences) > 0, "No experiences provided."
if experiences[0].experience_type == "dpo":
experiences = split_dpo_experience_to_single_turn(experiences)
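    # Pad every sample to the longest prompt and the longest response in the batch
    # so the per-experience fields can be stacked into dense tensors.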
max_prompt_length = max([exp.prompt_length for exp in experiences])
max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences]) # type: ignore
attention_mask = gather_attention_masks(
experiences, max_prompt_length, max_response_length
).long()
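    # Position ids are derived from the attention mask: cumsum - 1, clipped at 0,
    # so all padding positions before the first attended token map to position 0.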
cumsum = torch.cumsum(attention_mask, dim=-1)
position_ids = torch.clip(cumsum - 1, 0, None).long()
tokens = gather_token_ids(
experiences, max_prompt_length, max_response_length, pad_token_id
).long()
batch_dict = {
"uid": np.array([exp.eid.tid for exp in experiences]),
"unique_ids": np.array([exp.eid.uid for exp in experiences]),
"position_ids": position_ids,
"input_ids": tokens,
"responses": tokens[:, max_prompt_length:],
"attention_mask": attention_mask,
"response_mask": gather_action_masks(experiences, max_response_length),
}
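    # Rewards: per-token rewards take precedence over a single sequence-level reward;
    # either form also requires rollout log-probabilities (exported as old_log_probs).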
have_reward = all(exp.reward is not None for exp in experiences)
have_token_level_reward = all(exp.token_level_reward is not None for exp in experiences)
if have_reward or have_token_level_reward:
assert all(exp.logprobs is not None for exp in experiences), "No logprobs provided."
if have_token_level_reward:
if have_reward:
logger.warning(
"Both experiences.rewards and experiences.token_level_rewards are provided. "
"Using experiences.token_level_rewards."
)
token_level_rewards = gather_response_attrs(
experiences, "token_level_reward", max_response_length
)
else:
token_level_rewards = torch.zeros(attention_mask.shape, dtype=torch.float32)
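            # cumsum.argmax(-1) is the index of the last attended token in each row;
            # the scalar sequence reward is written onto that position.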
eos_mask_idx = cumsum.argmax(dim=-1)
token_level_rewards[torch.arange(len(experiences)), eos_mask_idx] = torch.tensor(
[exp.reward for exp in experiences]
)
token_level_rewards = token_level_rewards[:, max_prompt_length:]
batch_dict.update(
{
"token_level_scores": token_level_rewards,
"old_log_probs": gather_response_attrs(
experiences, "logprobs", max_response_length
),
}
)
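    # Optional per-token attributes are forwarded only when every experience provides them.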
for attr in ["advantages", "returns", "teacher_logprobs"]:
if all(getattr(exp, attr, None) is not None for exp in experiences):
batch_dict[attr] = gather_response_attrs(experiences, attr, max_response_length)
if all(exp.multi_modal_inputs is not None for exp in experiences):
keys = experiences[0].multi_modal_inputs.keys()
batch_dict["multi_modal_inputs"] = np.array(
[{key: exp.multi_modal_inputs[key] for key in keys} for exp in experiences], # type: ignore
dtype=object,
)
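    # Custom fields copy exp.info[source_field] into batch[destination_field];
    # all experiences must declare the same set of custom fields.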
custom_fields_set = set(tuple(exp.custom_fields) for exp in experiences)
if len(custom_fields_set) == 1:
custom_fields = list(custom_fields_set)[0]
for custom_field in custom_fields:
batch_dict[custom_field.destination_field] = torch.tensor(
[exp.info[custom_field.source_field] for exp in experiences],
dtype=custom_field.data_type,
)
else:
raise ValueError("Custom fields are not consistent across experiences.")
return DataProto.from_single_dict(batch_dict)
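
# A minimal usage sketch (illustrative only): ``build_experiences`` and ``tokenizer``
# are hypothetical stand-ins for whatever produces the experience list and supplies
# the pad token id in a given pipeline.
#
#     experiences = build_experiences(...)  # -> List[Experience]
#     batch = to_data_proto(experiences, tokenizer.pad_token_id, logger)
#     metrics = compute_data_metrics(batch)
#
# Note that reward/advantage statistics are only reported once the trainer has filled
# in the corresponding tensors (e.g. token_level_rewards, advantages, returns).
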
def compute_data_metrics(batch: DataProto) -> dict:
"""
Computes various metrics from a batch of data for PPO training.
Modified from verl.trainer.ppo.metric_utils.compute_data_metrics
This function calculates metrics related to scores, rewards, advantages, returns, values,
and sequence lengths from a batch of data. It provides statistical information (mean, max, min)
for each metric category.
Args:
batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc.
Returns:
A dictionary of metrics including:
- critic/score/mean, max, min: Statistics about sequence scores
- critic/rewards/mean, max, min: Statistics about sequence rewards
- critic/advantages/mean, max, min: Statistics about advantages
- critic/returns/mean, max, min: Statistics about returns
- critic/values/mean, max, min: Statistics about critic values
- critic/vf_explained_var: Explained variance of the value function
- response_length/mean, max, min, clip_ratio: Statistics about response lengths
- prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths
"""
metrics = {}
if "token_level_rewards" in batch.batch and "token_level_scores" in batch.batch:
sequence_score = batch.batch["token_level_scores"].sum(-1)
sequence_reward = batch.batch["token_level_rewards"].sum(-1)
metrics.update(
{
# score
"critic/score/mean": torch.mean(sequence_score).detach().item(),
"critic/score/max": torch.max(sequence_score).detach().item(),
"critic/score/min": torch.min(sequence_score).detach().item(),
# reward
"critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
"critic/rewards/max": torch.max(sequence_reward).detach().item(),
"critic/rewards/min": torch.min(sequence_reward).detach().item(),
}
)
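    # input_ids is laid out as [prompt | response]; split the attention mask into the
    # prompt part and the response part accordingly.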
max_response_length = batch.batch["responses"].shape[-1]
prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
max_prompt_length = prompt_mask.size(-1)
response_info = _compute_response_info(batch)
prompt_length = response_info["prompt_length"]
response_length = response_info["response_length"]
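    # clip_ratio is the fraction of samples whose prompt/response length equals the
    # padded batch maximum, i.e. sequences that (likely) hit the length limit.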
metrics.update(
{
# response length
"response_length/mean": torch.mean(response_length).detach().item(),
"response_length/max": torch.max(response_length).detach().item(),
"response_length/min": torch.min(response_length).detach().item(),
"response_length/clip_ratio": torch.mean(
torch.eq(response_length, max_response_length).float()
)
.detach()
.item(),
# prompt length
"prompt_length/mean": torch.mean(prompt_length).detach().item(),
"prompt_length/max": torch.max(prompt_length).detach().item(),
"prompt_length/min": torch.min(prompt_length).detach().item(),
"prompt_length/clip_ratio": torch.mean(
torch.eq(prompt_length, max_prompt_length).float()
)
.detach()
.item(),
}
)
if "advantages" in batch.batch:
# adv
advantages = batch.batch["advantages"]
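        # Fall back to a zero tensor when the response mask is empty so the
        # mean/max/min reductions below have something to reduce over.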
if response_mask.numel() > 0:
valid_adv = torch.masked_select(advantages, response_mask)
else:
valid_adv = torch.zeros(1)
metrics.update(
{
# adv
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
"critic/advantages/max": torch.max(valid_adv).detach().item(),
"critic/advantages/min": torch.min(valid_adv).detach().item(),
}
)
if "returns" in batch.batch:
# returns
returns = batch.batch["returns"]
if response_mask.numel() > 0:
valid_returns = torch.masked_select(returns, response_mask)
else:
valid_returns = torch.zeros(1)
metrics.update(
{
"critic/returns/mean": torch.mean(valid_returns).detach().item(),
"critic/returns/max": torch.max(valid_returns).detach().item(),
"critic/returns/min": torch.min(valid_returns).detach().item(),
}
)
return metrics
def get_latest_hf_checkpoint_path(config: Config):
"""Get the latest huggingface checkpoint path"""
if config.trainer.trainer_type != "verl":
raise ValueError("This function is only for verl trainer.")
    checkpoint_dir = find_latest_ckpt_path(config.checkpoint_job_dir)
    if checkpoint_dir is None:
        raise ValueError(f"No checkpoint found in {config.checkpoint_job_dir}")
    hf_checkpoint_dir = os.path.join(checkpoint_dir, "actor", "huggingface")
if not os.path.exists(hf_checkpoint_dir):
raise ValueError(f"No huggingface checkpoint found in {hf_checkpoint_dir}")
return hf_checkpoint_dir