# Source code for trinity.buffer.storage.file

"""File Storage"""
import json
import os
from typing import List

import ray

from trinity.buffer.utils import default_storage_path
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.experience import EID, Experience
from trinity.common.workflows import Task


class _Encoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, Experience):
            return o.to_dict()
        if isinstance(o, Task):
            return o.to_dict()
        if isinstance(o, EID):
            return o.to_dict()
        return super().default(o)


class FileStorage:
    """A write-only wrapper around a local JSON Lines file.

    If ``wrap_in_ray`` in ``StorageConfig`` is ``True``, :meth:`get_wrapper`
    runs this class as a named Ray actor that provides a remote interface to
    the local file; otherwise a plain in-process instance is returned.

    This wrapper is only for writing. To read from the file, use
    ``StorageType.QUEUE`` instead.
    """

    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
        """Validate the target path, create its directory and open it for appending.

        Args:
            storage_config: Storage settings. When ``storage_config.path`` is
                ``None`` it is filled in place with
                ``default_storage_path(storage_config, config)``.
            config: Buffer configuration, used only to derive the default path.

        Raises:
            ValueError: If the path does not end with ``.json`` or ``.jsonl``.
        """
        if storage_config.path is None:
            # Assigned onto the caller's config object, so the resolved path
            # is visible to anyone holding this StorageConfig.
            storage_config.path = default_storage_path(storage_config, config)
        ext = os.path.splitext(storage_config.path)[-1]
        if ext not in (".json", ".jsonl"):
            raise ValueError(
                f"File path must end with '.json' or '.jsonl', got {storage_config.path}"
            )
        path_dir = os.path.dirname(os.path.abspath(storage_config.path))
        os.makedirs(path_dir, exist_ok=True)
        # Remembered so the file can be reopened if this storage is acquired
        # again after every previous user released it (see ``acquire``).
        self._path = storage_config.path
        self.file = open(self._path, "a", encoding="utf-8")
        # ensure_ascii=False emits non-ASCII characters verbatim; the file is UTF-8.
        self.encoder = _Encoder(ensure_ascii=False)
        self.ref_count = 0

    @classmethod
    def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
        """Create a storage handle for ``storage_config``.

        Returns:
            A Ray actor handle when ``storage_config.wrap_in_ray`` is true,
            otherwise a local ``FileStorage`` instance. The actor is named
            ``json-<storage_config.name>`` and is created with
            ``get_if_exists=True``, so callers using the same name and
            namespace share one actor (and therefore one open file).
        """
        if storage_config.wrap_in_ray:
            return (
                ray.remote(cls)
                .options(
                    name=f"json-{storage_config.name}",
                    # Fall back to the current job's namespace when none is configured.
                    namespace=storage_config.ray_namespace
                    or ray.get_runtime_context().namespace,
                    get_if_exists=True,
                )
                .remote(storage_config, config)
            )
        return cls(storage_config, config)

    def write(self, data: List) -> None:
        """Append every item of ``data`` as one JSON line, then flush.

        All items are encoded before anything is written, so an item the
        encoder cannot serialize raises without leaving a partial batch in
        the file.

        Raises:
            TypeError: If an item (or a nested value) is not serializable by
                ``_Encoder``.
            ValueError: If the file was already closed by :meth:`release`.
        """
        lines = [self.encoder.encode(item) + "\n" for item in data]
        self.file.writelines(lines)
        self.file.flush()

    def read(self) -> List:
        """Not supported: this storage is write-only.

        Raises:
            NotImplementedError: Always; read through a QUEUE storage instead.
        """
        raise NotImplementedError(
            "read() is not implemented for FILE Storage, please use QUEUE instead"
        )

    def acquire(self) -> int:
        """Register one more user of this storage and return the new count.

        If every previous user already released the storage (which closed the
        file), the file is reopened in append mode so subsequent writes work,
        e.g. when a shared Ray actor is picked up again via ``get_if_exists``.
        """
        if self.file.closed:
            self.file = open(self._path, "a", encoding="utf-8")
        self.ref_count += 1
        return self.ref_count

    def release(self) -> int:
        """Drop one user and return the remaining count.

        The file is closed once the count reaches zero (or below); a later
        :meth:`acquire` reopens it.
        """
        self.ref_count -= 1
        if self.ref_count <= 0:
            # Closing an already-closed file object is a no-op, so extra
            # releases are harmless.
            self.file.close()
        return self.ref_count