Source code for trinity.algorithm.sample_strategy.sample_strategy

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple

from trinity.algorithm.sample_strategy.utils import representative_sample, to_data_proto
from trinity.buffer import get_buffer_reader
from trinity.common.config import BufferConfig
from trinity.common.experience import Experiences
from trinity.utils.registry import Registry
from trinity.utils.timer import Timer

# Registry of sample strategies. Strategy classes in this module add themselves
# with `@SAMPLE_STRATEGY.register_module(<name>)`.
SAMPLE_STRATEGY = Registry("sample_strategy")


class SampleStrategy(ABC):
    """Abstract interface for strategies that sample training data from a buffer.

    Concrete strategies must implement `sample`, `warmup_state` and
    `default_args`.
    """

    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs) -> None:
        """Store the settings shared by every strategy.

        Args:
            buffer_config (`BufferConfig`): Buffer configuration; only its
                `pad_token_id` is read here.
            trainer_type (`str`): Name of the trainer backend.
            **kwargs: Accepted and ignored by this base implementation.
        """
        self.pad_token_id = buffer_config.pad_token_id
        self.trainer_type = trainer_type

    @abstractmethod
    def sample(self, step: int) -> Tuple[Any, Dict, List]:
        """Draw the data for one step from the buffer.

        Args:
            step (`int`): Number of the current step.

        Returns:
            `Any`: The sampled data.
            `Dict`: Metrics to be logged.
            `List`: Representative data to be logged.
        """

    # Experimental API
    @abstractmethod
    def warmup_state(self, step: int) -> Tuple[bool, bool]:
        """Report the warmup state at the given step.

        Args:
            step (`int`): Number of the current step.

        Returns:
            `bool`: Whether the current step is in warmup.
            `bool`: Whether warmup finishes on this step.
        """

    @classmethod
    @abstractmethod
    def default_args(cls) -> dict:
        """Return the default arguments of the sample strategy."""
@SAMPLE_STRATEGY.register_module("warmup")
class WarmupSampleStrategy(SampleStrategy):
    """The default sample strategy.

    Reads from the SFT warmup dataset while `step <= sft_warmup_steps` and
    from the experience buffer afterwards.
    """

    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        """Create the buffer readers used by this strategy.

        Args:
            buffer_config (`BufferConfig`): Buffer configuration. Its
                `trainer_input` supplies `experience_buffer`,
                `sft_warmup_steps` and `sft_warmup_dataset`.
            trainer_type (`str`): Name of the trainer backend.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If `sft_warmup_steps > 0` but no `sft_warmup_dataset`
                is configured.
        """
        super().__init__(buffer_config, trainer_type)
        self.exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.experience_buffer, buffer_config  # type: ignore
        )
        self.sft_warmup_steps = buffer_config.trainer_input.sft_warmup_steps
        if self.sft_warmup_steps > 0 and buffer_config.trainer_input.sft_warmup_dataset is None:
            raise ValueError("sft_warmup_dataset is required when sft_warmup_steps > 0")
        if buffer_config.trainer_input.sft_warmup_dataset is not None:
            self.sft_buffer = get_buffer_reader(
                buffer_config.trainer_input.sft_warmup_dataset, buffer_config
            )
        else:
            self.sft_buffer = None

    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
        """Read one batch for `step` and convert it for the trainer backend.

        Args:
            step (`int`): Number of the current step. Steps with
                `step <= sft_warmup_steps` read from the SFT warmup buffer,
                later steps read from the experience buffer.
            **kwargs: Accepted and ignored.

        Returns:
            `Any`: The batch produced by `to_data_proto`.
            `Dict`: Metrics filled by the `read_time`, `gather_time` and
                `convert_time` timers.
            `List`: Output of `representative_sample` for logging.

        Raises:
            NotImplementedError: If `trainer_type` is not `"verl"`.
        """
        # Reject unsupported backends before touching the buffer, so the error
        # is not delayed behind a read and gather whose result would be
        # discarded anyway.
        if self.trainer_type != "verl":
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")

        metrics = {}
        with Timer(metrics, "read_time"):
            # NOTE(review): `sft_buffer` is None when no warmup dataset is
            # configured; this presumably relies on callers never passing a
            # step <= sft_warmup_steps in that case — confirm.
            if step <= self.sft_warmup_steps:
                exp_list = self.sft_buffer.read()
            else:
                exp_list = self.exp_buffer.read()
            repr_samples = representative_sample(exp_list)
        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
        with Timer(metrics, "convert_time"):
            data = to_data_proto(exps)
        return data, metrics, repr_samples

    def warmup_state(self, step: int) -> Tuple[bool, bool]:
        """Report the warmup state at the given step.

        Args:
            step (`int`): Number of the current step.

        Returns:
            `bool`: True while `step <= sft_warmup_steps`.
            `bool`: True when `step == sft_warmup_steps`.
        """
        return step <= self.sft_warmup_steps, step == self.sft_warmup_steps

    @classmethod
    def default_args(cls) -> dict:
        """Return the default arguments of this strategy (none)."""
        return {}
@SAMPLE_STRATEGY.register_module("default")
class DefaultSampleStrategy(SampleStrategy):
    """Sample strategy that always reads from the experience buffer.

    It has no warmup phase: `warmup_state` is always `(False, False)`.
    """

    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        """Create the experience buffer reader.

        Args:
            buffer_config (`BufferConfig`): Buffer configuration. Its
                `trainer_input.experience_buffer` selects the buffer to read.
            trainer_type (`str`): Name of the trainer backend.
            **kwargs: Accepted and ignored.
        """
        super().__init__(buffer_config, trainer_type)
        self.exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.experience_buffer, buffer_config  # type: ignore
        )

    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
        """Read one batch from the experience buffer and convert it.

        Args:
            step (`int`): Number of the current step (not used by this
                strategy).
            **kwargs: Accepted and ignored.

        Returns:
            `Any`: The batch produced by `to_data_proto`.
            `Dict`: Metrics filled by the `read_time`, `gather_time` and
                `convert_time` timers.
            `List`: Output of `representative_sample` for logging.

        Raises:
            NotImplementedError: If `trainer_type` is not `"verl"`.
        """
        # Reject unsupported backends before touching the buffer, so the error
        # is not delayed behind a read and gather whose result would be
        # discarded anyway.
        if self.trainer_type != "verl":
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")

        metrics = {}
        with Timer(metrics, "read_time"):
            exp_list = self.exp_buffer.read()
            repr_samples = representative_sample(exp_list)
        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
        with Timer(metrics, "convert_time"):
            data = to_data_proto(exps)
        return data, metrics, repr_samples

    def warmup_state(self, step: int) -> Tuple[bool, bool]:
        """Report the warmup state; this strategy is never in warmup.

        Args:
            step (`int`): Number of the current step (not used).

        Returns:
            `bool`: Always False.
            `bool`: Always False.
        """
        return False, False

    @classmethod
    def default_args(cls) -> dict:
        """Return the default arguments of this strategy (none)."""
        return {}
@SAMPLE_STRATEGY.register_module("dpo")
class DPOSampleStrategy(WarmupSampleStrategy):
    """Warmup-style sample strategy that gathers batches with
    `Experiences.gather_dpo_experiences`.

    Buffer selection follows `WarmupSampleStrategy`: the SFT warmup buffer
    while `step <= sft_warmup_steps`, the experience buffer afterwards.
    """

    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
        """Read one batch for `step` and convert it for the trainer backend.

        Args:
            step (`int`): Number of the current step. Steps with
                `step <= sft_warmup_steps` read from the SFT warmup buffer,
                later steps read from the experience buffer.
            **kwargs: Accepted and ignored.

        Returns:
            `Any`: The batch produced by `to_data_proto`.
            `Dict`: Metrics filled by the `read_time`, `gather_time` and
                `convert_time` timers.
            `List`: Output of `representative_sample` for logging.

        Raises:
            NotImplementedError: If `trainer_type` is not `"verl"`.
        """
        # Reject unsupported backends before touching the buffer, so the error
        # is not delayed behind a read and gather whose result would be
        # discarded anyway.
        if self.trainer_type != "verl":
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")

        metrics = {}
        with Timer(metrics, "read_time"):
            if step <= self.sft_warmup_steps:
                exp_list = self.sft_buffer.read()
            else:
                exp_list = self.exp_buffer.read()
            repr_samples = representative_sample(exp_list)
        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_dpo_experiences(exp_list, pad_token_id=self.pad_token_id)  # type: ignore
        with Timer(metrics, "convert_time"):
            data = to_data_proto(exps)
        return data, metrics, repr_samples