# -*- coding: utf-8 -*-
"""State manager."""
import json
import os
from typing import Optional
from trinity.common.config import Config, load_config
from trinity.utils.log import get_logger
class StateManager:
    """Persist and restore the running state of the Explorer and the Trainer.

    State is kept as small JSON files inside a cache directory:

    * ``stage_meta.json`` -- written by :meth:`save_stage`.
    * ``<explorer_name>_meta.json`` -- written by :meth:`save_explorer`.
    * ``<trainer_name>_meta.json`` -- written by :meth:`save_trainer`.

    Every save goes through a temporary file that is moved over the target
    with ``os.replace``, so a save that fails or is interrupted part-way
    leaves the previously saved state file untouched.
    """

    def __init__(
        self,
        path: str,
        trainer_name: Optional[str] = None,
        explorer_name: Optional[str] = None,
        config: Optional["Config"] = None,
        check_config: bool = False,
    ):
        """Create the cache directory and derive the state file paths.

        Args:
            path: Directory holding the state files; created if missing.
            trainer_name: Prefix of the trainer state file name.
            explorer_name: Prefix of the explorer state file name.
            config: Config to compare against the backup stored in ``path``.
            check_config: When true and ``config`` is given, run the
                consistency check against ``<path>/config.json``.

        Raises:
            ValueError: If the consistency check runs and the configs differ.
        """
        self.logger = get_logger(__name__, in_ray_actor=True)
        self.cache_dir = path
        os.makedirs(self.cache_dir, exist_ok=True)
        self.stage_state_path = os.path.join(self.cache_dir, "stage_meta.json")
        # NOTE: a name left as None is formatted literally, i.e. the file is
        # called "None_meta.json"; if both names are None the explorer and
        # trainer paths are the same file.
        self.explorer_state_path = os.path.join(self.cache_dir, f"{explorer_name}_meta.json")
        self.trainer_state_path = os.path.join(self.cache_dir, f"{trainer_name}_meta.json")
        if check_config and config is not None:
            self._check_config_consistency(config)

    def _check_config_consistency(self, config: "Config") -> None:
        """Check if the config is consistent with the cache dir backup.

        On first use (no ``config.json`` in the cache dir) the given config is
        saved as the backup. Otherwise the backup is loaded and compared.

        Raises:
            ValueError: If the backup config differs from ``config``.
        """
        backup_config_path = os.path.join(self.cache_dir, "config.json")
        if not os.path.exists(backup_config_path):
            config.save(backup_config_path)
            return
        backup_config = load_config(backup_config_path)
        if backup_config != config:
            message = (
                f"The current config is inconsistent with the backup config in {backup_config_path}."
            )
            self.logger.warning(message)
            raise ValueError(message)

    def _write_state(self, path: str, state: dict) -> None:
        """Serialize ``state`` as JSON to ``path`` without clobbering on failure."""
        # Write next to the target so os.replace stays on one filesystem.
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(state, f, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            # Do not leave a half-written temp file behind; the original
            # exception is what the caller needs to see.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _read_state(self, path: str, kind: str, resume_hint: str) -> dict:
        """Load the JSON object stored at ``path``.

        Args:
            path: State file to read.
            kind: Label used in log messages ("explorer", "trainer", "stage").
            resume_hint: Sentence logged after the loaded state.

        Returns:
            The stored dict, or ``{}`` if the file is missing, unreadable,
            not valid JSON, or not a JSON object.
        """
        if not os.path.exists(path):
            return {}
        try:
            with open(path, "r", encoding="utf-8") as f:
                state = json.load(f)
            if not isinstance(state, dict):
                # Callers are promised a dict; anything else is a bad file.
                raise ValueError(f"expected a JSON object, got {type(state).__name__}")
            self.logger.info(
                "----------------------------------\n"
                f"Found existing {kind} checkpoint:\n"
                f" > {state}\n"
                f"{resume_hint}\n"
                "----------------------------------"
            )
            return state
        except Exception as e:
            # Best effort: a broken state file means "start fresh", not a crash.
            self.logger.warning(f"Failed to load {kind} state file: {e}")
        return {}

    def save_explorer(
        self,
        current_task_index: int,
        current_step: int,
    ) -> None:
        """Save the explorer's latest task index and step."""
        self._write_state(
            self.explorer_state_path,
            {
                "latest_task_index": current_task_index,
                "latest_iteration": current_step,
            },
        )

    def load_explorer(self) -> dict:
        """Return the saved explorer state, or ``{}`` if none can be loaded."""
        return self._read_state(
            self.explorer_state_path,
            "explorer",
            "Continue exploring from this point.",
        )

    def save_trainer(
        self,
        current_exp_index: int,
        current_step: int,
    ) -> None:
        """Save the trainer's latest experience index and step."""
        self._write_state(
            self.trainer_state_path,
            {
                "latest_exp_index": current_exp_index,
                "latest_iteration": current_step,
            },
        )

    def load_trainer(self) -> dict:
        """Return the saved trainer state, or ``{}`` if none can be loaded."""
        return self._read_state(
            self.trainer_state_path,
            "trainer",
            "Continue training from this point.",
        )

    def save_stage(self, current_stage: int) -> None:
        """Save the latest stage index."""
        self._write_state(self.stage_state_path, {"latest_stage": current_stage})

    def load_stage(self) -> dict:
        """Return the saved stage state, or ``{}`` if none can be loaded."""
        return self._read_state(
            self.stage_state_path,
            "stage",
            "Continue from this point.",
        )