Source code for trinity.buffer.storage.queue

"""Ray Queue storage"""
import asyncio
import time
from abc import ABC, abstractmethod
from collections import deque
from copy import deepcopy
from functools import partial
from typing import List, Optional

import ray
from sortedcontainers import SortedDict

from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
from trinity.common.experience import Experience
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry


[docs] def is_database_url(path: str) -> bool: return any(path.startswith(prefix) for prefix in ["sqlite:///", "postgresql://", "mysql://"])
[docs] def is_json_file(path: str) -> bool: return path.endswith(".json") or path.endswith(".jsonl")
PRIORITY_FUNC = Registry("priority_fn")
[docs] @PRIORITY_FUNC.register_module("linear_decay") def linear_decay_priority(item: List[Experience], decay: float = 0.1): return item[0].info["model_version"] - decay * item[0].info["use_count"] # type: ignore
[docs] class QueueBuffer(ABC):
[docs] @abstractmethod async def put(self, exps: List[Experience]) -> None: """Put a list of experiences into the queue."""
[docs] @abstractmethod async def get(self) -> List[Experience]: """Get a list of experience from the queue."""
[docs] @abstractmethod def qsize(self) -> int: """Get the current size of the queue."""
[docs] @abstractmethod async def close(self) -> None: """Close the queue."""
[docs] @abstractmethod def stopped(self) -> bool: """Check if there is no more data to read."""
[docs] @classmethod def get_queue(cls, storage_config: StorageConfig, config: BufferConfig) -> "QueueBuffer": """Get a queue instance based on the storage configuration.""" logger = get_logger(__name__) if storage_config.use_priority_queue: reuse_cooldown_time = storage_config.reuse_cooldown_time replay_buffer_kwargs = storage_config.replay_buffer_kwargs capacity = min(storage_config.capacity, config.train_batch_size * 2) logger.info( f"Using AsyncPriorityQueue with capacity {capacity}, reuse_cooldown_time {reuse_cooldown_time}." ) return AsyncPriorityQueue(capacity, reuse_cooldown_time, **replay_buffer_kwargs) else: return AsyncQueue(capacity=storage_config.capacity)
[docs] class AsyncQueue(asyncio.Queue, QueueBuffer):
[docs] def __init__(self, capacity: int): """ Initialize the async queue with a specified capacity. Args: capacity (`int`): The maximum number of items the queue can hold. """ super().__init__(maxsize=capacity) self._closed = False
[docs] async def close(self) -> None: """Close the queue.""" self._closed = True for getter in self._getters: if not getter.done(): getter.set_exception(StopAsyncIteration()) self._getters.clear()
[docs] def stopped(self) -> bool: """Check if there is no more data to read.""" return self._closed and self.empty()
[docs] class AsyncPriorityQueue(QueueBuffer): """ An asynchronous priority queue that manages a fixed-size buffer of experience items. Items are prioritized using a user-defined function and reinserted after a cooldown period. Attributes: capacity (int): Maximum number of items the queue can hold. This value is automatically adjusted to be at most twice the read batch size. priority_groups (SortedDict): Maps priorities to deques of items with the same priority. priority_fn (callable): Function used to determine the priority of an item. reuse_cooldown_time (float): Delay before reusing an item (set to infinity to disable). """
[docs] def __init__( self, capacity: int, reuse_cooldown_time: Optional[float] = None, priority_fn: str = "linear_decay", **kwargs, ): """ Initialize the async priority queue. Args: capacity (`int`): The maximum number of items the queue can store. reuse_cooldown_time (`float`): Time to wait before reusing an item. Set to None to disable reuse. priority_fn (`str`): Name of the function to use for determining item priority. kwargs: Additional keyword arguments for the priority function. """ self.capacity = capacity self.priority_groups = SortedDict() # Maps priority -> deque of items self.priority_fn = partial(PRIORITY_FUNC.get(priority_fn), **kwargs) self.reuse_cooldown_time = reuse_cooldown_time self._condition = asyncio.Condition() # For thread-safe operations self._closed = False
async def _put(self, item: List[Experience], delay: float = 0) -> None: """ Insert an item into the queue, replacing the lowest-priority item if full. Args: item (`List[Experience]`): A list of experiences to add. delay (`float`): Optional delay before insertion (for simulating timing behavior). """ if delay > 0: await asyncio.sleep(delay) if len(item) == 0: return priority = self.priority_fn(item=item) async with self._condition: if len(self.priority_groups) == self.capacity: # If full, only insert if new item has higher or equal priority than the lowest lowest_priority, item_queue = self.priority_groups.peekitem(index=0) if lowest_priority > priority: return # Skip insertion if lower priority # Remove the lowest priority item item_queue.popleft() if not item_queue: self.priority_groups.popitem(index=0) # Add the new item if priority not in self.priority_groups: self.priority_groups[priority] = deque() self.priority_groups[priority].append(item) self._condition.notify()
[docs] async def put(self, item: List[Experience]) -> None: await self._put(item, delay=0)
[docs] async def get(self) -> List[Experience]: """ Retrieve the highest-priority item from the queue. Returns: List[Experience]: The highest-priority item (list of experiences). Notes: - After retrieval, the item is optionally reinserted after a cooldown period. """ async with self._condition: while len(self.priority_groups) == 0: if self._closed: raise StopAsyncIteration() await self._condition.wait() _, item_queue = self.priority_groups.peekitem(index=-1) item = item_queue.popleft() if not item_queue: self.priority_groups.popitem(index=-1) for exp in item: exp.info["use_count"] += 1 # Optionally resubmit the item after a cooldown if self.reuse_cooldown_time is not None: asyncio.create_task(self._put(item, self.reuse_cooldown_time)) return item
[docs] def qsize(self): return len(self.priority_groups)
[docs] async def close(self) -> None: """ Close the queue. """ async with self._condition: self._closed = True # No more items will be added, but existing items can still be processed. self.reuse_cooldown_time = None self._condition.notify_all()
[docs] def stopped(self) -> bool: return self._closed and len(self.priority_groups) == 0
[docs] class QueueStorage: """An wrapper of a async queue."""
[docs] def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: self.logger = get_logger(f"queue_{storage_config.name}", in_ray_actor=True) self.config = config self.capacity = storage_config.capacity self.queue = QueueBuffer.get_queue(storage_config, config) st_config = deepcopy(storage_config) st_config.wrap_in_ray = False if st_config.path is not None: if is_database_url(st_config.path): from trinity.buffer.writer.sql_writer import SQLWriter st_config.storage_type = StorageType.SQL self.writer = SQLWriter(st_config, self.config) elif is_json_file(st_config.path): from trinity.buffer.writer.file_writer import JSONWriter st_config.storage_type = StorageType.FILE self.writer = JSONWriter(st_config, self.config) else: self.logger.warning("Unknown supported storage path: %s", st_config.path) self.writer = None else: from trinity.buffer.writer.file_writer import JSONWriter st_config.storage_type = StorageType.FILE self.writer = JSONWriter(st_config, self.config) self.logger.warning(f"Save experiences in {st_config.path}.") self.ref_count = 0 self.exp_pool = deque() # A pool to store experiences self.closed = False
[docs] async def acquire(self) -> int: self.ref_count += 1 return self.ref_count
[docs] async def release(self) -> int: """Release the queue.""" self.ref_count -= 1 if self.ref_count <= 0: await self.queue.close() if self.writer is not None: await self.writer.release() return self.ref_count
[docs] def length(self) -> int: """The length of the queue.""" return self.queue.qsize()
[docs] async def put_batch(self, exp_list: List) -> None: """Put batch of experience.""" await self.queue.put(exp_list) if self.writer is not None: self.writer.write(exp_list)
[docs] async def get_batch(self, batch_size: int, timeout: float) -> List: """Get batch of experience.""" start_time = time.time() while len(self.exp_pool) < batch_size: if self.queue.stopped(): # If the queue is stopped, ignore the rest of the experiences in the pool raise StopAsyncIteration("Queue is closed and no more items to get.") try: exp_list = await asyncio.wait_for(self.queue.get(), timeout=1.0) self.exp_pool.extend(exp_list) except asyncio.TimeoutError: if time.time() - start_time > timeout: self.logger.error( f"Timeout when waiting for experience, only get {len(self.exp_pool)} experiences.\n" "This phenomenon is usually caused by the workflow not returning enough " "experiences or running timeout. Please check your workflow implementation." ) batch = list(self.exp_pool) self.exp_pool.clear() return batch return [self.exp_pool.popleft() for _ in range(batch_size)]
[docs] @classmethod def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): """Get the queue actor.""" return ( ray.remote(cls) .options( name=f"queue-{storage_config.name}", namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace, get_if_exists=True, ) .remote(storage_config, config) )