# Source code for trinity.algorithm.sample_strategy.sample_strategy

from abc import ABC, abstractmethod
from typing import Dict, List, Tuple

from trinity.algorithm.sample_strategy.utils import representative_sample
from trinity.buffer import get_buffer_reader
from trinity.common.config import BufferConfig
from trinity.common.experience import Experience, Experiences
from trinity.utils.annotations import Deprecated
from trinity.utils.monitor import gather_metrics
from trinity.utils.registry import Registry
from trinity.utils.timer import Timer

# Registry for sample strategies. Classes in this module add themselves under a
# string name via the ``@SAMPLE_STRATEGY.register_module(<name>)`` decorator.
SAMPLE_STRATEGY = Registry("sample_strategy")


class SampleStrategy(ABC):
    """Abstract interface for strategies that sample experiences from a buffer.

    Concrete subclasses must implement :meth:`sample`, :meth:`default_args`,
    :meth:`state_dict` and :meth:`load_state_dict`.
    """

    def __init__(self, buffer_config: BufferConfig, **kwargs) -> None:
        """Remember the padding token id from the buffer configuration.

        Args:
            buffer_config (`BufferConfig`): Buffer configuration; only its
                ``pad_token_id`` attribute is read here.
            **kwargs: Accepted for signature compatibility; unused by the
                base class.
        """
        self.pad_token_id = buffer_config.pad_token_id

    def set_model_version_metric(self, exp_list: List[Experience], metrics: Dict) -> None:
        """Add gathered ``model_version`` metrics to ``metrics`` in place.

        Experiences whose ``info`` has no ``"model_version"`` key are skipped.

        Args:
            exp_list (`List[Experience]`): Experiences to inspect.
            metrics (`Dict`): Metrics dict, updated in place with the result of
                ``gather_metrics(..., "sample")``.
        """
        metric_list = [
            {"model_version": exp.info["model_version"]}
            for exp in exp_list
            if "model_version" in exp.info
        ]
        metrics.update(gather_metrics(metric_list, "sample"))

    @abstractmethod
    async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
        """Sample data from buffer.

        Args:
            step (`int`): The step number of current step.

        Returns:
            `Experiences`: The sampled Experiences data.
            `Dict`: Metrics for logging.
            `List`: Representative data for logging.
        """

    @classmethod
    @abstractmethod
    def default_args(cls) -> dict:
        """Get the default arguments of the sample strategy."""

    @abstractmethod
    def state_dict(self) -> dict:
        """Get the state dict of the sample strategy."""

    @abstractmethod
    def load_state_dict(self, state_dict: dict) -> None:
        """Load the state dict of the sample strategy."""
@SAMPLE_STRATEGY.register_module("default")
class DefaultSampleStrategy(SampleStrategy):
    """Sample strategy registered as ``"default"``.

    Reads experiences from the reader built for
    ``buffer_config.trainer_input.experience_buffer``.
    """

    def __init__(self, buffer_config: BufferConfig, **kwargs):
        """Create the experience buffer reader.

        Args:
            buffer_config (`BufferConfig`): Buffer configuration; its
                ``trainer_input.experience_buffer`` is passed to
                ``get_buffer_reader``.
            **kwargs: Accepted for signature compatibility; not forwarded.
        """
        super().__init__(buffer_config)
        self.exp_buffer = get_buffer_reader(buffer_config.trainer_input.experience_buffer)  # type: ignore[arg-type]

    async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
        """Read one batch from the buffer and gather it into ``Experiences``.

        Args:
            step (`int`): The step number of current step (not used here).
            **kwargs: Accepted for signature compatibility; unused.

        Returns:
            `Experiences`: Result of ``Experiences.gather_experiences``.
            `Dict`: Metrics for logging, including the two ``Timer`` entries
                and the model-version metrics.
            `List`: Output of ``representative_sample`` for logging.
        """
        metrics = {}
        # NOTE(review): statement nesting under the two Timer blocks was
        # reconstructed from a whitespace-flattened listing -- confirm which
        # statements each timer is meant to cover.
        with Timer(metrics, "time/read_experience"):
            exp_list = await self.exp_buffer.read_async()
            repr_samples = representative_sample(exp_list)
        self.set_model_version_metric(exp_list, metrics)
        with Timer(metrics, "time/gather_experience"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
        return exps, metrics, repr_samples

    @classmethod
    def default_args(cls) -> dict:
        """Return the default arguments: an empty dict."""
        return {}

    def state_dict(self) -> dict:
        """Return the state dict of the underlying experience buffer reader."""
        return self.exp_buffer.state_dict()

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore the buffer reader state; a falsy ``state_dict`` is ignored."""
        if state_dict:
            self.exp_buffer.load_state_dict(state_dict)
@Deprecated
@SAMPLE_STRATEGY.register_module("warmup")
class WarmupSampleStrategy(DefaultSampleStrategy):
    """The warmup sample strategy.

    Deprecated, keep this class for backward compatibility only.
    Please use DefaultSampleStrategy instead.
    """

    def __init__(self, buffer_config: BufferConfig, **kwargs):
        """Initialize exactly like ``DefaultSampleStrategy``.

        Args:
            buffer_config (`BufferConfig`): Passed to the parent constructor.
            **kwargs: Accepted for signature compatibility; not forwarded.
        """
        super().__init__(buffer_config)