Source code for trinity.buffer.reader.sql_reader
"""Reader of the SQL buffer."""
from typing import Dict, List, Optional
import ray
from trinity.buffer.buffer_reader import BufferReader
from trinity.buffer.reader.reader import READER
from trinity.buffer.storage.sql import SQLStorage
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType
[docs]
@READER.register_module("sql")
class SQLReader(BufferReader):
    """Reader of the SQL buffer.

    Every read is delegated to the wrapper returned by
    ``SQLStorage.get_wrapper(config)``. When ``config.wrap_in_ray`` is true,
    the wrapper is used as a Ray handle: ``read`` is invoked with
    ``.remote(...)`` and resolved with ``ray.get``. Otherwise the wrapper is
    called directly in-process.
    """

    def __init__(self, config: StorageConfig) -> None:
        """Create a reader for an SQL-backed buffer.

        Args:
            config: Storage configuration. Its ``storage_type`` must equal
                ``StorageType.SQL.value``; its ``wrap_in_ray`` flag selects
                remote (Ray) or local access to the storage.
        """
        assert config.storage_type == StorageType.SQL.value, (
            f"SQLReader requires storage_type '{StorageType.SQL.value}', "
            f"got '{config.storage_type}'"
        )
        self.wrap_in_ray = config.wrap_in_ray
        self.storage = SQLStorage.get_wrapper(config)

    def read(self, batch_size: Optional[int] = None) -> List:
        """Read one batch from the SQL storage.

        Args:
            batch_size: Number of items to request. ``None`` is forwarded
                unchanged to the storage (presumably meaning its own default
                batch size; confirm in ``SQLStorage.read``).

        Returns:
            The list produced by the storage's ``read``.

        Raises:
            StopIteration: Propagated from the storage's ``read``.
                ``read_async`` treats it as the end-of-data signal.
        """
        if self.wrap_in_ray:
            return ray.get(self.storage.read.remote(batch_size))
        return self.storage.read(batch_size)

    async def read_async(self, batch_size: Optional[int] = None) -> List:
        """Async variant of ``read`` with async-iteration termination semantics.

        Args:
            batch_size: Forwarded unchanged to ``read``.

        Returns:
            The list produced by the storage's ``read``.

        Raises:
            StopAsyncIteration: When the storage raises ``StopIteration``,
                on both the Ray and the in-process path.
        """
        # A StopIteration escaping a coroutine is converted by Python into
        # RuntimeError("coroutine raised StopIteration"), so it must be
        # translated on BOTH paths, not only the Ray one.
        # NOTE(review): on the Ray path, ``read`` calls ``ray.get``, which
        # blocks the running event loop until the remote read completes.
        # Awaiting the ObjectRef would avoid that, but verify how an
        # actor-raised StopIteration surfaces through ``await`` before switching.
        try:
            return self.read(batch_size)
        except StopIteration as exc:
            raise StopAsyncIteration from exc

    def state_dict(self) -> Dict:
        """Return the reader's checkpoint state.

        State tracking is not implemented for the SQL buffer yet, so this
        always reports a fixed ``{"current_index": 0}``.
        """
        return {"current_index": 0}

    def load_state_dict(self, state_dict: Dict):
        """Restore checkpoint state.

        No-op: the SQL buffer does not support state restoration yet, so
        ``state_dict`` is ignored.
        """
        return None