Source code for trinity.trainer.verl.utils

"""Utils for ccompatibility issues with verl."""

import numpy as np
import torch
from verl import DataProto
from verl.trainer.ppo.metric_utils import _compute_response_info

from trinity.common.experience import Experiences


def to_data_proto(experiences: Experiences) -> DataProto:
    """Convert Experiences to verl DataProto."""
    attention_mask = experiences.attention_masks
    cumsum = torch.cumsum(attention_mask, dim=-1)
    position_ids = torch.clip(cumsum - 1, 0, None).long()
    batch_dict = {
        "uid": np.array([eid.tid for eid in experiences.eids]),
        "unique_ids": np.array([eid.uid for eid in experiences.eids]),
        "position_ids": position_ids,
        "input_ids": experiences.tokens.long(),
        "responses": experiences.tokens[:, experiences.prompt_length :].long(),
        "attention_mask": attention_mask.long(),
        "response_mask": (
            experiences.action_masks.long()
            if hasattr(experiences, "action_masks") and experiences.action_masks is not None
            else attention_mask[:, experiences.prompt_length :].long()
        ),
    }
    if experiences.rewards is not None:
        token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
        eos_mask_idx = cumsum.argmax(dim=-1)
        token_level_rewards[
            torch.arange(experiences.batch_size), eos_mask_idx
        ] = experiences.rewards
        token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
        batch_dict.update(
            {
                "token_level_scores": token_level_rewards,
                "old_log_probs": experiences.logprobs,  # type: ignore
            }
        )
    if experiences.advantages is not None:
        batch_dict["advantages"] = experiences.advantages
    if experiences.returns is not None:
        batch_dict["returns"] = experiences.returns
    if experiences.custom_fields:
        for field in experiences.custom_fields:
            if hasattr(experiences, field):
                batch_dict[field] = getattr(experiences, field)
    return DataProto.from_single_dict(batch_dict)
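

# Illustrative usage sketch: ``to_data_proto`` only reads attributes from its
# argument, so a SimpleNamespace stand-in with the same fields is enough to
# demonstrate the conversion here. The field values and shapes below are
# assumptions for the example; in real training the Experiences object comes
# from trinity.common.experience.Experiences.
def _example_to_data_proto() -> DataProto:
    from types import SimpleNamespace

    batch_size, prompt_length, response_length = 2, 4, 3
    seq_len = prompt_length + response_length
    fake_experiences = SimpleNamespace(
        eids=[SimpleNamespace(tid=f"task-{i}", uid=f"exp-{i}") for i in range(batch_size)],
        tokens=torch.randint(0, 100, (batch_size, seq_len)),
        attention_masks=torch.ones(batch_size, seq_len, dtype=torch.long),
        action_masks=None,  # falls back to the response slice of attention_mask
        prompt_length=prompt_length,
        batch_size=batch_size,
        rewards=torch.tensor([1.0, 0.0]),  # one scalar reward per sequence
        logprobs=torch.zeros(batch_size, response_length),
        advantages=None,
        returns=None,
        custom_fields=[],
    )
    data_proto = to_data_proto(fake_experiences)  # type: ignore[arg-type]
    # Each sequence reward lands on its last attended token (cumsum.argmax),
    # then the prompt part is sliced off, giving response-aligned
    # "token_level_scores" alongside "old_log_probs".
    return data_proto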
def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict:
    """
    Computes various metrics from a batch of data for PPO training.

    Modified from verl.trainer.ppo.metric_utils.compute_data_metrics

    This function calculates metrics related to scores, rewards, advantages,
    returns, values, and sequence lengths from a batch of data. It provides
    statistical information (mean, max, min) for each metric category.

    Args:
        batch: A DataProto object containing batch data with token-level scores,
            rewards, advantages, etc.
        use_critic: Whether to include critic-specific metrics. Defaults to False.

    Returns:
        A dictionary of metrics including:
            - critic/score/mean, max, min: Statistics about sequence scores
            - critic/rewards/mean, max, min: Statistics about sequence rewards
            - critic/advantages/mean, max, min: Statistics about advantages
            - critic/returns/mean, max, min: Statistics about returns
            - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)
            - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)
            - response_length/mean, max, min, clip_ratio: Statistics about response lengths
            - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths
    """
    metrics = {}
    if "token_level_rewards" in batch.batch and "token_level_scores" in batch.batch:
        sequence_score = batch.batch["token_level_scores"].sum(-1)
        sequence_reward = batch.batch["token_level_rewards"].sum(-1)
        metrics.update(
            {
                # score
                "critic/score/mean": torch.mean(sequence_score).detach().item(),
                "critic/score/max": torch.max(sequence_score).detach().item(),
                "critic/score/min": torch.min(sequence_score).detach().item(),
                # reward
                "critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
                "critic/rewards/max": torch.max(sequence_reward).detach().item(),
                "critic/rewards/min": torch.min(sequence_reward).detach().item(),
            }
        )

    max_response_length = batch.batch["responses"].shape[-1]

    prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
    response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()

    max_prompt_length = prompt_mask.size(-1)

    response_info = _compute_response_info(batch)
    prompt_length = response_info["prompt_length"]
    response_length = response_info["response_length"]

    metrics.update(
        {
            # response length
            "response_length/mean": torch.mean(response_length).detach().item(),
            "response_length/max": torch.max(response_length).detach().item(),
            "response_length/min": torch.min(response_length).detach().item(),
            "response_length/clip_ratio": torch.mean(
                torch.eq(response_length, max_response_length).float()
            )
            .detach()
            .item(),
            # prompt length
            "prompt_length/mean": torch.mean(prompt_length).detach().item(),
            "prompt_length/max": torch.max(prompt_length).detach().item(),
            "prompt_length/min": torch.min(prompt_length).detach().item(),
            "prompt_length/clip_ratio": torch.mean(
                torch.eq(prompt_length, max_prompt_length).float()
            )
            .detach()
            .item(),
        }
    )

    if "advantages" in batch.batch:
        # adv
        advantages = batch.batch["advantages"]
        valid_adv = torch.masked_select(advantages, response_mask)
        metrics.update(
            {
                "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
                "critic/advantages/max": torch.max(valid_adv).detach().item(),
                "critic/advantages/min": torch.min(valid_adv).detach().item(),
            }
        )

    if "returns" in batch.batch:
        # returns
        returns = batch.batch["returns"]
        valid_returns = torch.masked_select(returns, response_mask)
        metrics.update(
            {
                "critic/returns/mean": torch.mean(valid_returns).detach().item(),
                "critic/returns/max": torch.max(valid_returns).detach().item(),
                "critic/returns/min": torch.min(valid_returns).detach().item(),
            }
        )

    return metrics
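

# Illustrative usage sketch, building on the stand-in batch above. The metrics
# function expects a "token_level_rewards" key; in the full pipeline this is
# typically derived from "token_level_scores" (e.g. after a KL penalty), so
# here we simply mirror the scores to populate the score/reward statistics.
def _example_compute_data_metrics() -> dict:
    batch = _example_to_data_proto()
    batch.batch["token_level_rewards"] = batch.batch["token_level_scores"].clone()
    metrics = compute_data_metrics(batch, use_critic=False)
    # e.g. metrics["critic/rewards/mean"], metrics["response_length/clip_ratio"]
    return metrics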