from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple
from trinity.algorithm.sample_strategy.utils import representative_sample, to_data_proto
from trinity.buffer import get_buffer_reader
from trinity.common.config import BufferConfig
from trinity.common.experience import Experiences
from trinity.utils.registry import Registry
from trinity.utils.timer import Timer
# Registry of sample strategy classes. Strategies in this module add themselves
# with `@SAMPLE_STRATEGY.register_module(<name>)` ("warmup", "default", "dpo").
SAMPLE_STRATEGY = Registry("sample_strategy")
class SampleStrategy(ABC):
    """Abstract interface for strategies that sample training data from a buffer.

    Subclasses implement `sample`, `warmup_state` and `default_args`.
    """

    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs) -> None:
        """Store the settings shared by every strategy.

        Args:
            buffer_config (`BufferConfig`): Buffer configuration; only its
                `pad_token_id` is read here.
            trainer_type (`str`): Name of the trainer backend the sampled data
                is prepared for.
            **kwargs: Accepted so that all strategies can be constructed with
                the same call; ignored by this base class.
        """
        self.pad_token_id = buffer_config.pad_token_id
        self.trainer_type = trainer_type

    @abstractmethod
    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
        """Sample data from buffer.

        Args:
            step (`int`): The step number of current step.
            **kwargs: Optional strategy-specific arguments. Declared here so
                the abstract signature matches the concrete strategies in
                this module, which all accept `**kwargs`.

        Returns:
            `Any`: The sampled data.
            `Dict`: Metrics for logging.
            `List`: Representative data for logging.
        """

    # Experimental API
    @abstractmethod
    def warmup_state(self, step: int) -> Tuple[bool, bool]:
        """Check the warmup state of the current step.

        Args:
            step (`int`): The step number of current step.

        Returns:
            `bool`: Current step is in warmup or not.
            `bool`: Warmup is finished on this step or not.
        """

    @classmethod
    @abstractmethod
    def default_args(cls) -> dict:
        """Get the default arguments of the sample strategy."""
@SAMPLE_STRATEGY.register_module("warmup")
class WarmupSampleStrategy(SampleStrategy):
    """The default sample strategy.

    Reads from the SFT warmup dataset while `step <= sft_warmup_steps` and
    from the experience buffer afterwards.
    """

    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        """Create the buffer readers used for sampling.

        Args:
            buffer_config (`BufferConfig`): Buffer configuration. Its
                `trainer_input` provides `experience_buffer`,
                `sft_warmup_steps` and `sft_warmup_dataset`.
            trainer_type (`str`): Name of the trainer backend.
            **kwargs: Accepted for a uniform constructor signature; unused.

        Raises:
            ValueError: If `sft_warmup_steps > 0` but no `sft_warmup_dataset`
                is configured.
        """
        super().__init__(buffer_config, trainer_type)
        trainer_input = buffer_config.trainer_input
        self.exp_buffer = get_buffer_reader(
            trainer_input.experience_buffer, buffer_config  # type: ignore
        )
        self.sft_warmup_steps = trainer_input.sft_warmup_steps
        if self.sft_warmup_steps > 0 and trainer_input.sft_warmup_dataset is None:
            raise ValueError("sft_warmup_dataset is required when sft_warmup_steps > 0")
        # The SFT reader is optional: it only exists when a dataset is configured.
        if trainer_input.sft_warmup_dataset is not None:
            self.sft_buffer = get_buffer_reader(trainer_input.sft_warmup_dataset, buffer_config)
        else:
            self.sft_buffer = None

    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
        """Read one batch for `step` and convert it for the trainer backend.

        Args:
            step (`int`): The step number of current step.
            **kwargs: Unused.

        Returns:
            `Any`: The sampled data converted with `to_data_proto`.
            `Dict`: Timing metrics (`read_time`, `gather_time`, `convert_time`).
            `List`: Representative samples of the batch for logging.

        Raises:
            NotImplementedError: If `trainer_type` is not `"verl"`.
        """
        # Reject unsupported backends up front so that no buffer read or
        # gather work is performed only to be discarded.
        if self.trainer_type != "verl":
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")

        metrics: Dict = {}
        with Timer(metrics, "read_time"):
            # Without a warmup dataset there is no SFT reader (e.g. step 0
            # with sft_warmup_steps == 0); read the experience buffer instead
            # of failing on `None.read()`.
            use_sft = step <= self.sft_warmup_steps and self.sft_buffer is not None
            reader = self.sft_buffer if use_sft else self.exp_buffer
            exp_list = reader.read()
            repr_samples = representative_sample(exp_list)
        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
        with Timer(metrics, "convert_time"):
            data = to_data_proto(exps)
        return data, metrics, repr_samples

    def warmup_state(self, step: int) -> Tuple[bool, bool]:
        """Report the warmup state purely from the step number.

        Args:
            step (`int`): The step number of current step.

        Returns:
            `bool`: `True` while `step <= sft_warmup_steps`.
            `bool`: `True` when `step == sft_warmup_steps`.
        """
        return step <= self.sft_warmup_steps, step == self.sft_warmup_steps

    @classmethod
    def default_args(cls) -> dict:
        """Get the default arguments of the sample strategy (none)."""
        return {}
@SAMPLE_STRATEGY.register_module("default")
class DefaultSampleStrategy(SampleStrategy):
    """Sample strategy that always reads from the experience buffer (no warmup)."""

    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        """Create the experience buffer reader.

        Args:
            buffer_config (`BufferConfig`): Buffer configuration. Its
                `trainer_input.experience_buffer` selects the buffer to read.
            trainer_type (`str`): Name of the trainer backend.
            **kwargs: Accepted for a uniform constructor signature; unused.
        """
        super().__init__(buffer_config, trainer_type)
        self.exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.experience_buffer, buffer_config  # type: ignore
        )

    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
        """Read one batch from the experience buffer and convert it for the backend.

        Args:
            step (`int`): The step number of current step. Not used by this
                strategy.
            **kwargs: Unused.

        Returns:
            `Any`: The sampled data converted with `to_data_proto`.
            `Dict`: Timing metrics (`read_time`, `gather_time`, `convert_time`).
            `List`: Representative samples of the batch for logging.

        Raises:
            NotImplementedError: If `trainer_type` is not `"verl"`.
        """
        # Reject unsupported backends up front so that no buffer read or
        # gather work is performed only to be discarded.
        if self.trainer_type != "verl":
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")

        metrics: Dict = {}
        with Timer(metrics, "read_time"):
            exp_list = self.exp_buffer.read()
            repr_samples = representative_sample(exp_list)
        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
        with Timer(metrics, "convert_time"):
            data = to_data_proto(exps)
        return data, metrics, repr_samples

    def warmup_state(self, step: int) -> Tuple[bool, bool]:
        """Report that this strategy is never in warmup.

        Args:
            step (`int`): The step number of current step. Not used.

        Returns:
            `bool`: Always `False` (not in warmup).
            `bool`: Always `False` (warmup never finishes on a step).
        """
        return False, False

    @classmethod
    def default_args(cls) -> dict:
        """Get the default arguments of the sample strategy (none)."""
        return {}
@SAMPLE_STRATEGY.register_module("dpo")
class DPOSampleStrategy(WarmupSampleStrategy):
    """Warmup-aware strategy that gathers batches with `gather_dpo_experiences`."""

    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
        """Read one batch for `step` and gather it as DPO experiences.

        Args:
            step (`int`): The step number of current step.
            **kwargs: Unused.

        Returns:
            `Any`: The sampled data converted with `to_data_proto`.
            `Dict`: Timing metrics (`read_time`, `gather_time`, `convert_time`).
            `List`: Representative samples of the batch for logging.

        Raises:
            NotImplementedError: If `trainer_type` is not `"verl"`.
        """
        # Reject unsupported backends up front so that no buffer read or
        # gather work is performed only to be discarded.
        if self.trainer_type != "verl":
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")

        metrics: Dict = {}
        with Timer(metrics, "read_time"):
            # Without a warmup dataset there is no SFT reader (e.g. step 0
            # with sft_warmup_steps == 0); read the experience buffer instead
            # of failing on `None.read()`.
            use_sft = step <= self.sft_warmup_steps and self.sft_buffer is not None
            reader = self.sft_buffer if use_sft else self.exp_buffer
            exp_list = reader.read()
            repr_samples = representative_sample(exp_list)
        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_dpo_experiences(  # type: ignore
                exp_list, pad_token_id=self.pad_token_id
            )
        with Timer(metrics, "convert_time"):
            data = to_data_proto(exps)
        return data, metrics, repr_samples