Source code for trinity.buffer.queue

"""A queue implemented by Ray Actor."""
import asyncio
from copy import deepcopy
from typing import List

import ray

from trinity.buffer.writer.file_writer import JSONWriter
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
from trinity.utils.log import get_logger


def is_database_url(path: str) -> bool:
    """Tell whether ``path`` is a database URL this module knows how to write to.

    Only the ``sqlite:///``, ``postgresql://`` and ``mysql://`` prefixes are
    recognised; the comparison is case-sensitive.
    """
    known_schemes = ("sqlite:///", "postgresql://", "mysql://")
    return path.startswith(known_schemes)
def is_json_file(path: str) -> bool:
    """Tell whether ``path`` names a JSON or JSON-Lines file, judged by suffix.

    Matches a trailing ``.json`` or ``.jsonl``; the comparison is case-sensitive.
    """
    json_suffixes = (".json", ".jsonl")
    return path.endswith(json_suffixes)
class QueueActor:
    """An asyncio.Queue based queue actor.

    Producers push lists of experiences with :meth:`put_batch`; consumers pull
    them with :meth:`get_batch`. Every pushed list is also mirrored to a writer
    (SQL or JSON, chosen from the storage path) when one could be created.
    Users register with :meth:`acquire` and unregister with :meth:`release`;
    once the reference count drops to zero a ``FINISH_MESSAGE`` sentinel is
    enqueued and consumers receive ``StopAsyncIteration``.
    """

    FINISH_MESSAGE = "$FINISH$"

    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
        """Create the queue and the writer that mirrors pushed experiences.

        Args:
            storage_config: Storage settings. A deep copy is modified, so the
                caller's object is left untouched.
            config: Buffer settings. Its ``capacity`` attribute (default 10000)
                bounds how many batches may be pending in the queue at once.
        """
        self.logger = get_logger(__name__)
        self.config = config
        self.capacity = getattr(config, "capacity", 10000)
        # Each put_batch call enqueues one item, so capacity counts batches.
        self.queue = asyncio.Queue(self.capacity)
        st_config = deepcopy(storage_config)
        # NOTE(review): presumably keeps the writer from being wrapped in its
        # own Ray actor, since it already lives inside this one -- confirm.
        st_config.wrap_in_ray = False
        if st_config.path is not None:
            if is_database_url(st_config.path):
                st_config.storage_type = StorageType.SQL
                self.writer = SQLWriter(st_config, self.config)
            elif is_json_file(st_config.path):
                st_config.storage_type = StorageType.FILE
                self.writer = JSONWriter(st_config, self.config)
            else:
                # No writer: experiences are queued but not persisted.
                self.logger.warning("Unsupported storage path: %s", st_config.path)
                self.writer = None
        else:
            st_config.storage_type = StorageType.FILE
            self.writer = JSONWriter(st_config, self.config)
            # NOTE(review): JSONWriter may fill in a default path on st_config;
            # otherwise this reports None.
            self.logger.warning("Save experiences in %s.", st_config.path)
        self.ref_count = 0

    async def acquire(self) -> int:
        """Register one more user of the queue.

        Returns:
            The reference count after incrementing.
        """
        self.ref_count += 1
        return self.ref_count

    async def release(self) -> int:
        """Release the queue.

        Decrements the reference count. When it reaches zero (or below), the
        finish sentinel is enqueued so consumers stop, and the writer, if any,
        is released.

        Returns:
            The reference count after decrementing.
        """
        self.ref_count -= 1
        if self.ref_count <= 0:
            await self.queue.put(self.FINISH_MESSAGE)
            # The writer is None when the storage path was not recognised.
            if self.writer is not None:
                self.writer.release()
        return self.ref_count

    def length(self) -> int:
        """Return the number of pending queue items.

        Counts enqueued batches (not individual experiences), plus the finish
        sentinel once it has been enqueued.
        """
        return self.queue.qsize()

    async def put_batch(self, exp_list: List) -> None:
        """Put batch of experience.

        The whole list is enqueued as a single item; this waits while the queue
        already holds ``capacity`` batches. The list is then mirrored to the
        writer when one exists.
        """
        await self.queue.put(exp_list)
        if self.writer is not None:
            self.writer.write(exp_list)

    async def get_batch(self, batch_size: int) -> List:
        """Get batch of experience.

        Concatenates enqueued batches until at least ``batch_size`` experiences
        are collected. Whole batches are taken, so the result may be longer than
        ``batch_size``.

        Raises:
            StopAsyncIteration: when the finish sentinel is reached. Experiences
                already collected during this call are discarded.
        """
        batch = []
        while True:
            exp_list = await self.queue.get()
            if exp_list == self.FINISH_MESSAGE:
                # Put the sentinel back so other (and later) consumers also stop
                # instead of blocking forever on an empty queue. The get() above
                # freed a slot and there is no await in between, so this cannot
                # raise QueueFull.
                self.queue.put_nowait(self.FINISH_MESSAGE)
                raise StopAsyncIteration()
            batch.extend(exp_list)
            if len(batch) >= batch_size:
                break
        return batch

    @classmethod
    def get_actor(cls, storage_config: StorageConfig, config: BufferConfig):
        """Get the queue actor.

        Creates a Ray actor named ``queue-<storage_config.name>``. Because of
        ``get_if_exists=True``, an existing actor with that name in the target
        namespace is returned instead of creating a new one. The namespace is
        ``storage_config.ray_namespace`` if set, else the current job's.
        """
        return (
            ray.remote(cls)
            .options(
                name=f"queue-{storage_config.name}",
                namespace=storage_config.ray_namespace
                or ray.get_runtime_context().namespace,
                get_if_exists=True,
            )
            .remote(storage_config, config)
        )