Source code for trinity.manager.manager

# -*- coding: utf-8 -*-
"""Data manager."""
import json
import os

from trinity.common.config import Config, load_config
from trinity.utils.log import get_logger

# Module-level logger used by CacheManager for meta-file and config-consistency messages.
logger = get_logger(__name__)


class CacheManager:
    """A Manager class for managing the cache dir.

    Persists explorer/trainer progress metadata as JSON files inside
    ``config.monitor.cache_dir`` and can optionally verify that the current
    config matches a ``config.json`` backup stored in that directory.
    """

    def __init__(self, config: Config, check_config: bool = False):
        """Resolve meta-file paths under the configured cache dir.

        Args:
            config: Configuration; ``config.monitor.cache_dir`` is used as the cache dir.
            check_config: If True, compare ``config`` against the backup in the
                cache dir (writing the backup if none exists yet).

        Raises:
            ValueError: If ``check_config`` is True and the backup config differs.
        """
        self.cache_dir = config.monitor.cache_dir  # type: ignore
        self.explorer_meta_path = os.path.join(self.cache_dir, "explorer_meta.json")  # type: ignore
        self.trainer_meta_path = os.path.join(self.cache_dir, "trainer_meta.json")  # type: ignore
        if check_config:
            self._check_config_consistency(config)

    def _check_config_consistency(self, config: Config) -> None:
        """Check if the config is consistent with the cache dir backup.

        On first use the current config is saved as the backup; afterwards any
        difference from that backup is treated as an error.

        Raises:
            ValueError: If the current config differs from the backup.
        """
        backup_config_path = os.path.join(self.cache_dir, "config.json")  # type: ignore
        if not os.path.exists(backup_config_path):
            config.save(backup_config_path)
            return
        backup_config = load_config(backup_config_path)
        if backup_config != config:
            message = (
                f"The current config is inconsistent with the backup config in {backup_config_path}."
            )
            logger.warning(message)
            raise ValueError(message)

    @staticmethod
    def _write_meta(path: str, meta: dict) -> None:
        """Write ``meta`` as indented JSON to ``path`` atomically.

        The data goes to a sibling temp file that is then moved over ``path``
        with ``os.replace``. A crash mid-write therefore cannot leave a
        truncated meta file, which the loaders would treat as "no progress".
        """
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(meta, f, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            # Don't leave a half-written temp file behind; re-raise the original error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    @staticmethod
    def _load_meta(path: str, name: str) -> dict:
        """Load a JSON meta file, returning ``{}`` if it is missing or unusable.

        Args:
            path: Path of the meta file.
            name: Human-readable meta kind (``"explorer"``/``"trainer"``) for log messages.

        Returns:
            The decoded dict, or ``{}`` when the file does not exist, cannot be
            read/decoded, or does not contain a JSON object.
        """
        if not os.path.exists(path):
            return {}
        try:
            with open(path, "r", encoding="utf-8") as f:
                meta = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError/UnicodeDecodeError
            logger.warning(f"Failed to load {name} meta file: {e}")
            return {}
        if not isinstance(meta, dict):
            logger.warning(f"Failed to load {name} meta file: expected a JSON object, got {type(meta).__name__}")
            return {}
        logger.info(f"Find existing {name} meta: {meta}")
        return meta

    def save_explorer(self, current_task_index: int, current_step: int) -> None:
        """Persist the explorer's latest task index and step to ``explorer_meta.json``."""
        self._write_meta(
            self.explorer_meta_path,
            {"latest_task_index": current_task_index, "latest_iteration": current_step},
        )

    def load_explorer(self) -> dict:
        """Return the saved explorer meta, or ``{}`` if unavailable."""
        return self._load_meta(self.explorer_meta_path, "explorer")

    def save_trainer(self, current_step: int) -> None:
        """Persist the trainer's latest step to ``trainer_meta.json``."""
        self._write_meta(self.trainer_meta_path, {"latest_iteration": current_step})

    def load_trainer(self) -> dict:
        """Return the saved trainer meta, or ``{}`` if unavailable."""
        return self._load_meta(self.trainer_meta_path, "trainer")