"""A queue implemented by Ray Actor."""
import asyncio
from copy import deepcopy
from typing import List
import ray
from trinity.buffer.writer.file_writer import JSONWriter
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
from trinity.utils.log import get_logger
def is_database_url(path: str) -> bool:
    """Return whether *path* starts with one of the recognized SQL URL schemes.

    Recognized prefixes are ``sqlite:///``, ``postgresql://`` and ``mysql://``.
    Matching is a case-sensitive prefix check.
    """
    supported_prefixes = ("sqlite:///", "postgresql://", "mysql://")
    return path.startswith(supported_prefixes)
def is_json_file(path: str) -> bool:
    """Return whether *path* ends with a ``.json`` or ``.jsonl`` suffix.

    The suffix check is case-sensitive.
    """
    json_suffixes = (".json", ".jsonl")
    return path.endswith(json_suffixes)
class QueueActor:
    """An asyncio.Queue based queue actor.

    Producers push batches of experiences with ``put_batch`` and consumers pull
    them with ``get_batch``. When a writer could be created from the storage
    config, every pushed batch is also handed to that writer. Users register
    with ``acquire`` and deregister with ``release``; once the reference count
    drops to zero, a finish sentinel is enqueued so readers stop.
    """

    # Sentinel enqueued by ``release`` to signal readers that no more data comes.
    FINISH_MESSAGE = "$FINISH$"

    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
        """Create the bounded queue and the optional persistence writer.

        Args:
            storage_config: Storage settings. ``path`` selects the writer: a
                ``sqlite:///``/``postgresql://``/``mysql://`` URL uses
                ``SQLWriter``, a ``.json``/``.jsonl`` path or ``None`` uses
                ``JSONWriter``, anything else disables persistence. The object
                is deep-copied, so the caller's config is not modified.
            config: Buffer settings. ``capacity`` (10000 when the attribute is
                absent) is the maximum number of queued batches.
        """
        self.logger = get_logger(__name__)
        self.config = config
        self.capacity = getattr(config, "capacity", 10000)
        self.queue = asyncio.Queue(self.capacity)
        st_config = deepcopy(storage_config)
        # NOTE(review): presumably disabled because the writer already runs
        # inside this actor; confirm against the writer implementations.
        st_config.wrap_in_ray = False
        if st_config.path is not None:
            if is_database_url(st_config.path):
                st_config.storage_type = StorageType.SQL
                self.writer = SQLWriter(st_config, self.config)
            elif is_json_file(st_config.path):
                st_config.storage_type = StorageType.FILE
                self.writer = JSONWriter(st_config, self.config)
            else:
                # Experiences are still queued, just not persisted.
                self.logger.warning("Unknown or unsupported storage path: %s", st_config.path)
                self.writer = None
        else:
            st_config.storage_type = StorageType.FILE
            self.writer = JSONWriter(st_config, self.config)
            # Logged after JSONWriter is built; presumably the writer may fill
            # in a default path on st_config — TODO confirm.
            self.logger.warning("Save experiences in %s.", st_config.path)
        self.ref_count = 0

    async def acquire(self) -> int:
        """Register one more user of the queue.

        Returns:
            The reference count after incrementing.
        """
        self.ref_count += 1
        return self.ref_count

    async def release(self) -> int:
        """Release the queue.

        Decrements the reference count. When it reaches zero (or below), the
        finish sentinel is enqueued so that readers stop, and the writer, if
        one exists, is released.

        Returns:
            The reference count after decrementing.
        """
        self.ref_count -= 1
        if self.ref_count <= 0:
            await self.queue.put(self.FINISH_MESSAGE)
            # ``writer`` is None when the storage path was not recognized.
            if self.writer is not None:
                self.writer.release()
        return self.ref_count

    def length(self) -> int:
        """The length of the queue.

        Counts queued entries (whole batches, plus the finish sentinel if it is
        enqueued), not individual experiences.
        """
        return self.queue.qsize()

    async def put_batch(self, exp_list: List) -> None:
        """Put batch of experience.

        Waits while the queue is full, then forwards the batch to the writer
        (if any) for persistence.
        """
        await self.queue.put(exp_list)
        if self.writer is not None:
            self.writer.write(exp_list)

    async def get_batch(self, batch_size: int) -> List:
        """Get batch of experience.

        Consumes whole queued batches until at least ``batch_size`` experiences
        are collected, so the result may be longer than ``batch_size``.

        Raises:
            StopAsyncIteration: When the finish sentinel is reached. Experiences
                gathered by this call before the sentinel are discarded.
        """
        batch = []
        while True:
            exp_list = await self.queue.get()
            if exp_list == self.FINISH_MESSAGE:
                # Put the sentinel back so concurrent and later readers also stop
                # instead of blocking forever. The get above freed a slot and
                # nothing awaits in between, so put_nowait cannot raise QueueFull.
                self.queue.put_nowait(self.FINISH_MESSAGE)
                raise StopAsyncIteration()
            batch.extend(exp_list)
            if len(batch) >= batch_size:
                break
        return batch

    @classmethod
    def get_actor(cls, storage_config: StorageConfig, config: BufferConfig):
        """Get the queue actor.

        Creates (or, via ``get_if_exists=True``, reuses) the named Ray actor
        ``queue-<storage_config.name>`` in ``storage_config.ray_namespace``,
        falling back to the current runtime namespace when that is empty.
        """
        return (
            ray.remote(cls)
            .options(
                name=f"queue-{storage_config.name}",
                namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
                get_if_exists=True,
            )
            .remote(storage_config, config)
        )