Source code for trinity.buffer.storage.sql

"""SQL database storage"""

import time
from abc import abstractmethod
from typing import Dict, List, Optional

import ray
from datasets import Dataset
from sqlalchemy import asc, desc
from sqlalchemy.orm import sessionmaker

from trinity.buffer.schema import init_engine
from trinity.buffer.schema.formatter import FORMATTER, TaskFormatter
from trinity.buffer.utils import default_storage_path, retry_session
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.experience import Experience
from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.workflows import WORKFLOWS, Task
from trinity.utils.log import get_logger


class SQLStorage:
    """A storage backed by a SQL database.

    If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as a Ray Actor
    and provide a remote interface to the local database. For databases that do not
    support multi-process read/write (e.g. SQLite, DuckDB), please set `wrap_in_ray`
    to `True`.
    """

    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
        self.logger = get_logger(f"sql_{storage_config.name}", in_ray_actor=True)
        if storage_config.path is None:
            storage_config.path = default_storage_path(storage_config, config)
        self.engine, self.table_model_cls = init_engine(
            db_url=storage_config.path,
            table_name=storage_config.name,
            schema_type=storage_config.schema_type,
        )
        self.logger.info(f"Init SQL storage at {storage_config.path}")
        self.session = sessionmaker(bind=self.engine)
        self.max_retry_times = storage_config.max_retry_times
        self.max_retry_interval = storage_config.max_retry_interval
        self.ref_count = 0
        self.stopped = False
        # Assume that the auto-increment ID starts counting from 1, so the default offset should be 0.
        self.offset = storage_config.index

    @classmethod
    def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
        if storage_config.schema_type is None:
            storage_cls = SQLTaskStorage
        else:
            storage_cls = SQLExperienceStorage
        if storage_config.wrap_in_ray:
            return (
                ray.remote(storage_cls)
                .options(
                    name=f"sql-{storage_config.name}",
                    namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
                    get_if_exists=True,
                    max_concurrency=5,
                )
                .remote(storage_config, config)
            )
        else:
            return storage_cls(storage_config, config)

    @abstractmethod
    def write(self, data: List) -> None:
        """Write a batch of data."""

    @abstractmethod
    def read(self, batch_size: Optional[int] = None) -> List:
        """Read a batch of data."""

    def acquire(self) -> int:
        self.ref_count += 1
        return self.ref_count

    def release(self) -> int:
        self.ref_count -= 1
        if self.ref_count <= 0:
            self.stopped = True
        return self.ref_count

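# Illustrative usage sketch (not part of the module): how a caller might obtain and
# release a storage handle via `get_wrapper`. The concrete config objects below are
# assumed to exist; only the field names used above are relied upon.
#
#   storage = SQLStorage.get_wrapper(storage_config, buffer_config)
#   if storage_config.wrap_in_ray:
#       # `storage` is a Ray actor handle; every call goes through `.remote()`.
#       ray.get(storage.acquire.remote())
#       ray.get(storage.write.remote(batch))
#       ray.get(storage.release.remote())  # marks the storage as stopped once ref_count reaches 0
#   else:
#       # `storage` is a plain SQLTaskStorage / SQLExperienceStorage instance.
#       storage.acquire()
#       storage.write(batch)
#       storage.release()
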
class SQLExperienceStorage(SQLStorage):
    """Used as trainer input."""

    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
        super().__init__(storage_config, config)
        self.batch_size = config.train_batch_size
        self.max_timeout = storage_config.max_read_timeout
        # TODO: optimize the following logic
        if storage_config.schema_type == "experience":
            # NOTE: consistent with the old version of experience buffer
            self._read_method = self._read_priority
        else:
            # SFT / DPO uses FIFO style
            self._read_method = self._read_fifo

    def write(self, data: List[Experience]) -> None:
        with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
            experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
            session.add_all(experience_models)
            self.logger.info(f"Write {len(experience_models)} experiences to SQL storage.")

    def _read_fifo(self, batch_size: int) -> List[Experience]:
        """Read experiences in FIFO order."""
        exp_list = []
        start_time = time.time()
        while len(exp_list) < batch_size:
            if self.stopped:
                raise StopIteration()
            if time.time() - start_time > self.max_timeout:
                self.logger.warning(
                    f"Max read timeout reached ({self.max_timeout} s), only got {len(exp_list)} experiences, stopping..."
                )
                raise StopIteration()
            with retry_session(
                self.session, self.max_retry_times, self.max_retry_interval
            ) as session:
                # get a batch of experiences from the database
                experiences = (
                    session.query(self.table_model_cls)
                    .filter(self.table_model_cls.id > self.offset)
                    .order_by(asc(self.table_model_cls.id))
                    .limit(batch_size - len(exp_list))
                    .all()
                )
                if experiences:
                    self.offset = experiences[-1].id
                    start_time = time.time()
                    exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences])
            if len(exp_list) < batch_size:
                self.logger.info(f"Waiting for {batch_size - len(exp_list)} more experiences...")
                time.sleep(1)
        return exp_list

    def _read_priority(self, batch_size: int) -> List[Experience]:
        """Read the least-consumed experiences first, incrementing their `consumed` counter."""
        exp_list = []
        start_time = time.time()
        latest_size = 0
        while latest_size < batch_size:
            if self.stopped:
                raise StopIteration()
            if time.time() - start_time > self.max_timeout:
                self.logger.warning(
                    f"Max read timeout reached ({self.max_timeout} s), only got {latest_size} experiences, stopping..."
                )
                raise StopIteration()
            with retry_session(
                self.session, self.max_retry_times, self.max_retry_interval
            ) as session:
                experiences = (
                    session.query(self.table_model_cls)
                    .order_by(asc(self.table_model_cls.consumed), desc(self.table_model_cls.id))
                    .limit(batch_size)
                    .with_for_update()
                    .all()
                )
                if len(experiences) != batch_size:
                    if latest_size != len(experiences):
                        latest_size = len(experiences)
                        start_time = time.time()
                else:
                    ids = [exp.id for exp in experiences]
                    session.query(self.table_model_cls).filter(
                        self.table_model_cls.id.in_(ids)
                    ).update(
                        {self.table_model_cls.consumed: self.table_model_cls.consumed + 1},
                        synchronize_session=False,
                    )
                    exp_list.extend(
                        [self.table_model_cls.to_experience(exp) for exp in experiences]
                    )
                    break
            self.logger.info(f"Waiting for {batch_size - len(exp_list)} more experiences...")
            time.sleep(1)
        return exp_list

    def read(self, batch_size: Optional[int] = None) -> List[Experience]:
        if self.stopped:
            raise StopIteration()
        batch_size = batch_size or self.batch_size
        return self._read_method(batch_size)

    @classmethod
    def load_from_dataset(
        cls, dataset: Dataset, storage_config: StorageConfig, config: BufferConfig
    ) -> "SQLExperienceStorage":
        import transformers

        tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
        storage = cls(
            storage_config=storage_config,
            config=config,
        )
        formatter = FORMATTER.get(storage_config.schema_type)(tokenizer, storage_config.format)
        batch_size = storage.batch_size
        batch = []
        for item in dataset:
            batch.append(formatter.format(item))
            if len(batch) >= batch_size:
                storage.write(batch)
                batch.clear()
        if batch:
            storage.write(batch)
        return storage

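# Illustrative sketch (not part of the module): reading trainer batches from the
# experience storage. `dataset`, `storage_config`, and `buffer_config` are assumed
# to exist and follow the field names used above.
#
#   storage = SQLExperienceStorage.load_from_dataset(dataset, storage_config, buffer_config)
#   try:
#       while True:
#           batch = storage.read()  # blocks until `train_batch_size` experiences are available
#   except StopIteration:
#       pass  # raised when the storage is stopped or the read timeout is exceeded
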
class SQLTaskStorage(SQLStorage):
    """Used as explorer input."""

    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
        super().__init__(storage_config, config)
        self.batch_size = config.batch_size
        self.is_eval = storage_config.is_eval
        self.default_workflow_cls = WORKFLOWS.get(storage_config.default_workflow_type)  # type: ignore
        if self.is_eval and storage_config.default_eval_workflow_type:
            self.default_workflow_cls = WORKFLOWS.get(storage_config.default_eval_workflow_type)
        self.default_reward_fn_cls = REWARD_FUNCTIONS.get(storage_config.default_reward_fn_type)  # type: ignore
        self.formatter = TaskFormatter(storage_config)
        self.offset = storage_config.index
        if storage_config.total_steps:
            self.total_samples = self.batch_size * storage_config.total_steps
        else:
            if storage_config.total_epochs > 1:
                self.logger.warning(
                    f"SQL storage does not support total_epochs; the value {storage_config.total_epochs} will be ignored"
                )
            self.total_samples = float("inf")

    def write(self, data: List[Dict]) -> None:
        with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
            tasks = [self.table_model_cls.from_dict(item) for item in data]
            session.add_all(tasks)

    def read(self, batch_size: Optional[int] = None) -> List[Task]:
        if self.stopped:
            raise StopIteration()
        if self.offset > self.total_samples:
            raise StopIteration()
        batch_size = batch_size or self.batch_size
        with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
            query = (
                session.query(self.table_model_cls)
                .filter(self.table_model_cls.id > self.offset)
                .order_by(asc(self.table_model_cls.id))
                .limit(batch_size)
            )
            results = query.all()
            if len(results) == 0:
                raise StopIteration()
            if not self.is_eval and len(results) < batch_size:
                raise StopIteration()
            self.offset = results[-1].id
            return [self.formatter.format(item.raw_task) for item in results]

    @classmethod
    def load_from_dataset(
        cls, dataset: Dataset, storage_config: StorageConfig, config: BufferConfig
    ) -> "SQLTaskStorage":
        storage = cls(
            storage_config=storage_config,
            config=config,
        )
        batch_size = config.batch_size
        batch = []
        for item in dataset:
            batch.append(item)
            if len(batch) >= batch_size:
                storage.write(batch)
                batch.clear()
        if batch:
            storage.write(batch)
        return storage
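
# Illustrative sketch (not part of the module): iterating tasks for the explorer.
# `task_dataset`, `storage_config`, and `buffer_config` are assumed to exist.
#
#   task_storage = SQLTaskStorage.load_from_dataset(task_dataset, storage_config, buffer_config)
#   try:
#       while True:
#           tasks = task_storage.read()  # a batch of Task objects, read in ascending id order
#   except StopIteration:
#       pass  # no remaining rows, an incomplete non-eval batch, or total_steps reached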