Source code for trinity.algorithm.sample_strategy.mix_sample_strategy

import copy
from math import ceil
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from verl.trainer.ppo.ray_trainer import DataProto

from trinity.algorithm.sample_strategy.sample_strategy import (
    SAMPLE_STRATEGY,
    SampleStrategy,
)
from trinity.algorithm.sample_strategy.utils import representative_sample
from trinity.buffer import get_buffer_reader
from trinity.common.config import BufferConfig
from trinity.common.experience import Experiences
from trinity.utils.timer import Timer


@SAMPLE_STRATEGY.register_module("mix")
class MixSampleStrategy(SampleStrategy):
    """Sample strategy that mixes expert (SFT) experiences into each training batch."""
    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        super().__init__(buffer_config, trainer_type)
        self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5)
        tot_batch_size = buffer_config.read_batch_size
        expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size)

        # usual experience buffer
        usual_buffer_config = copy.deepcopy(buffer_config)
        usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size
        self.usual_exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.experience_buffer, usual_buffer_config  # type: ignore
        )

        if buffer_config.trainer_input.sft_warmup_dataset is None:
            raise ValueError(
                "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm"
            )

        # expert experience buffer
        expert_buffer_config = copy.deepcopy(buffer_config)
        expert_buffer_config.read_batch_size = expert_batch_size
        self.expert_exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config
        )
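    # --- Illustrative note (not part of the module) ---
    # A minimal sketch of how the batch is split, assuming read_batch_size=96 and
    # the default expert_data_ratio=0.5:
    #
    #     expert_batch_size = ceil(0.5 * 96)   # 48 experiences read from sft_warmup_dataset
    #     usual_batch_size  = 96 - 48          # 48 experiences read from experience_buffer
    #
    # sample() then concatenates the two reads into a single batch of 96 experiences,
    # each tagged via exp.info["is_expert"].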
    def sample(self, step: int) -> Tuple[Any, Dict, List]:
        metrics = {}
        with Timer(metrics, "read_time"):
            # usual (on-policy) experiences
            usual_exp_list = self.usual_exp_buffer.read()
            for exp in usual_exp_list:
                if exp.info is None:
                    exp.info = {}
                exp.info["is_expert"] = False

            # expert (SFT) experiences: zero reward and placeholder logprobs
            expert_exp_list = self.expert_exp_buffer.read()
            for exp in expert_exp_list:
                exp.reward = 0.0
                exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32)
                if exp.info is None:
                    exp.info = {}
                exp.info["is_expert"] = True

            exp_list = usual_exp_list + expert_exp_list
            repr_samples = representative_sample(exp_list)
            is_expert_mask = torch.tensor(
                [exp.info["is_expert"] for exp in exp_list], dtype=torch.bool
            )

        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore

        if self.trainer_type == "verl":
            with Timer(metrics, "convert_time"):
                data = to_data_proto_mix(exps, is_expert_mask)
            return data, metrics, repr_samples
        else:
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")
    @classmethod
    def default_args(cls) -> Dict:
        return {
            "expert_data_ratio": 0.5,
        }
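# --- Illustrative usage sketch (not part of the module) ---
# Assuming `buffer_config` is a BufferConfig whose `trainer_input` provides both an
# `experience_buffer` and an `sft_warmup_dataset` (the configuration itself is
# hypothetical), the strategy could be used roughly as follows:
#
#     strategy = MixSampleStrategy(buffer_config, trainer_type="verl", expert_data_ratio=0.25)
#     data, metrics, repr_samples = strategy.sample(step=0)
#
# Here `data` is a verl DataProto, `metrics` holds the read/gather/convert timings,
# and `repr_samples` is a representative subset of the raw experiences.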
def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.Tensor) -> DataProto:
    attention_mask = experiences.attention_masks
    cumsum = torch.cumsum(attention_mask, dim=-1)
    position_ids = torch.clip(cumsum - 1, 0, None).long()
    batch_dict = {
        "uid": np.array(experiences.run_ids),
        "position_ids": position_ids,
        "input_ids": experiences.tokens.long(),
        "responses": experiences.tokens[:, experiences.prompt_length :].long(),
        "attention_mask": attention_mask.long(),
        "response_mask": (
            experiences.action_masks[:, experiences.prompt_length :].long()
            if hasattr(experiences, "action_masks") and experiences.action_masks is not None
            else attention_mask[:, experiences.prompt_length :].long()
        ),
        "is_expert_mask": is_expert_mask,
    }
    if experiences.rewards is not None:
        token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
        eos_mask_idx = cumsum.argmax(dim=-1)
        token_level_rewards[
            torch.arange(experiences.batch_size), eos_mask_idx
        ] = experiences.rewards
        token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
        batch_dict.update(
            {
                "token_level_scores": token_level_rewards,
                "old_log_probs": experiences.logprobs[:, experiences.prompt_length :],  # type: ignore
            }
        )
    return DataProto.from_single_dict(batch_dict)
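# --- Illustrative sketch (not part of the module) ---
# A toy example of how to_data_proto_mix places a sequence-level reward on the last
# attended token before slicing off the prompt; shapes and values are made up purely
# for illustration.
#
#     attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])       # prompt_length = 2
#     cumsum = torch.cumsum(attention_mask, dim=-1)             # [[1, 2, 3, 4, 4, 4]]
#     eos_mask_idx = cumsum.argmax(dim=-1)                      # tensor([3]): last attended position
#     token_level_rewards = torch.zeros_like(attention_mask, dtype=torch.float32)
#     token_level_rewards[torch.arange(1), eos_mask_idx] = 1.5  # sequence-level reward
#     token_level_rewards[:, 2:]                                # [[0.0, 1.5, 0.0, 0.0]] -> "token_level_scores"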