"""Implementation of async queue buffers."""
import asyncio
from abc import ABC, abstractmethod
from collections import deque
from functools import partial
from typing import List, Optional
from sortedcontainers import SortedDict
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.experience import Experience
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry
PRIORITY_FUNC = Registry("priority_fn")


@PRIORITY_FUNC.register_module("linear_decay")
def linear_decay_priority(item: List[Experience], decay: float = 0.1) -> float:
    """Priority that increases with model version and decays linearly with reuse count."""
    return item[0].info["model_version"] - decay * item[0].info["use_count"]  # type: ignore
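
# Worked example (illustrative only, values assumed): with the default decay of 0.1,
# a batch whose first experience has info["model_version"] == 10 and
# info["use_count"] == 2 receives priority 10 - 0.1 * 2 = 9.8, so fresher and
# less frequently reused batches are served first.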


class QueueBuffer(ABC):
    """Abstract base class for asynchronous experience queue buffers."""

    @abstractmethod
    async def put(self, exps: List[Experience]) -> None:
        """Put a list of experiences into the queue."""

    @abstractmethod
    async def get(self) -> List[Experience]:
        """Get a list of experiences from the queue."""

    @abstractmethod
    def qsize(self) -> int:
        """Get the current size of the queue."""

    @abstractmethod
    async def close(self) -> None:
        """Close the queue."""

    @abstractmethod
    def stopped(self) -> bool:
        """Check if there is no more data to read."""

    @classmethod
    def get_queue(cls, storage_config: StorageConfig, config: BufferConfig) -> "QueueBuffer":
        """Get a queue instance based on the storage configuration."""
        logger = get_logger(__name__)
        if storage_config.use_priority_queue:
            reuse_cooldown_time = storage_config.reuse_cooldown_time
            replay_buffer_kwargs = storage_config.replay_buffer_kwargs
            capacity = min(storage_config.capacity, config.train_batch_size * 2)
            logger.info(
                f"Using AsyncPriorityQueue with capacity {capacity}, reuse_cooldown_time {reuse_cooldown_time}."
            )
            return AsyncPriorityQueue(capacity, reuse_cooldown_time, **replay_buffer_kwargs)
        else:
            return AsyncQueue(capacity=storage_config.capacity)
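
# Illustrative configuration sketch (field values are assumptions, not library defaults):
# with storage_config.use_priority_queue=True, storage_config.capacity=64, and
# config.train_batch_size=8, the factory builds an AsyncPriorityQueue with capacity
# min(64, 8 * 2) = 16; with use_priority_queue=False it returns AsyncQueue(capacity=64).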


class AsyncQueue(asyncio.Queue, QueueBuffer):
    """A FIFO queue buffer built on top of asyncio.Queue, with close/stop support."""

    def __init__(self, capacity: int):
        """
        Initialize the async queue with a specified capacity.

        Args:
            capacity (`int`): The maximum number of items the queue can hold.
        """
        super().__init__(maxsize=capacity)
        self._closed = False

    async def close(self) -> None:
        """Close the queue."""
        self._closed = True
        # Wake up any pending getters so they stop waiting on a closed queue.
        for getter in self._getters:
            if not getter.done():
                getter.set_exception(StopAsyncIteration())
        self._getters.clear()

    def stopped(self) -> bool:
        """Check if there is no more data to read."""
        return self._closed and self.empty()
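
# Consumer sketch (illustrative, not part of the original module): a reader that is
# already awaiting ``get()`` when ``close()`` runs receives StopAsyncIteration and can
# treat it as the end-of-stream signal, e.g.
#
#     try:
#         while True:
#             batch = await queue.get()
#             ...  # consume the batch
#     except StopAsyncIteration:
#         pass  # the queue was closed while waiting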


class AsyncPriorityQueue(QueueBuffer):
    """
    An asynchronous priority queue that manages a fixed-size buffer of experience items.

    Items are prioritized using a user-defined function and may be reinserted after a cooldown period.

    Attributes:
        capacity (int): Maximum number of items the queue can hold. This value is automatically
            adjusted to be at most twice the read batch size.
        priority_groups (SortedDict): Maps priorities to deques of items with the same priority.
        priority_fn (callable): Function used to determine the priority of an item.
        reuse_cooldown_time (Optional[float]): Delay before an item becomes available again
            (set to None to disable reuse).
    """

    def __init__(
        self,
        capacity: int,
        reuse_cooldown_time: Optional[float] = None,
        priority_fn: str = "linear_decay",
        **kwargs,
    ):
        """
        Initialize the async priority queue.

        Args:
            capacity (`int`): The maximum number of items the queue can store.
            reuse_cooldown_time (`float`): Time to wait before reusing an item. Set to None to disable reuse.
            priority_fn (`str`): Name of the function to use for determining item priority.
            kwargs: Additional keyword arguments for the priority function.
        """
        self.capacity = capacity
        self.priority_groups = SortedDict()  # Maps priority -> deque of items
        self.priority_fn = partial(PRIORITY_FUNC.get(priority_fn), **kwargs)
        self.reuse_cooldown_time = reuse_cooldown_time
        self._condition = asyncio.Condition()  # Synchronizes concurrent coroutine access to queue state
        self._closed = False

    async def _put(self, item: List[Experience], delay: float = 0) -> None:
        """
        Insert an item into the queue, replacing the lowest-priority item if full.

        Args:
            item (`List[Experience]`): A list of experiences to add.
            delay (`float`): Optional delay before insertion (used for the reuse cooldown).
        """
        if delay > 0:
            await asyncio.sleep(delay)
        priority = self.priority_fn(item=item)
        async with self._condition:
            if len(self.priority_groups) == self.capacity:
                # If full, only insert if the new item's priority is at least the current lowest
                lowest_priority, item_queue = self.priority_groups.peekitem(index=0)
                if lowest_priority > priority:
                    return  # Skip insertion if lower priority
                # Remove the lowest priority item
                item_queue.popleft()
                if not item_queue:
                    self.priority_groups.popitem(index=0)
            # Add the new item
            if priority not in self.priority_groups:
                self.priority_groups[priority] = deque()
            self.priority_groups[priority].append(item)
            self._condition.notify()
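
    # Eviction example (illustrative, values assumed): with capacity 2 and existing
    # priority groups {5.0, 7.0}, putting an item of priority 3.0 is dropped
    # (3.0 < lowest 5.0), whereas putting one of priority 6.0 first evicts an item
    # from the 5.0 group and then inserts the new item.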

    async def put(self, item: List[Experience]) -> None:
        """Put a list of experiences into the queue."""
        await self._put(item, delay=0)

    async def get(self) -> List[Experience]:
        """
        Retrieve the highest-priority item from the queue.

        Returns:
            List[Experience]: The highest-priority item (list of experiences).

        Notes:
            - After retrieval, the item is optionally reinserted after a cooldown period.
        """
        async with self._condition:
            while len(self.priority_groups) == 0:
                if self._closed:
                    raise StopAsyncIteration()
                await self._condition.wait()
            _, item_queue = self.priority_groups.peekitem(index=-1)
            item = item_queue.popleft()
            if not item_queue:
                self.priority_groups.popitem(index=-1)
            for exp in item:
                exp.info["use_count"] += 1
            # Optionally resubmit the item after a cooldown
            if self.reuse_cooldown_time is not None:
                asyncio.create_task(self._put(item, self.reuse_cooldown_time))
            return item

    def qsize(self) -> int:
        """Get the current size of the queue, measured as the number of distinct priority groups."""
        return len(self.priority_groups)

    async def close(self) -> None:
        """Close the queue."""
        async with self._condition:
            self._closed = True
            # No more items will be added, but existing items can still be processed.
            self.reuse_cooldown_time = None
            self._condition.notify_all()

    def stopped(self) -> bool:
        """Check if there is no more data to read."""
        return self._closed and len(self.priority_groups) == 0
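

# --- Illustrative end-to-end sketch (hedged; the ``Experience`` construction below is
# hypothetical and may not match the real dataclass signature) ---
#
#     async def _demo() -> None:
#         queue = AsyncPriorityQueue(capacity=4, reuse_cooldown_time=None)
#         exp = Experience(...)  # hypothetical: must carry info["model_version"] and info["use_count"]
#         exp.info = {"model_version": 3, "use_count": 0}
#         await queue.put([exp])
#         batch = await queue.get()   # highest-priority batch; use_count is incremented to 1
#         await queue.close()         # once closed and empty, get() raises StopAsyncIteration
#
#     asyncio.run(_demo())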