"""Monitor"""
import os
from abc import ABC, abstractmethod
from typing import List, Optional, Union
import numpy as np
import pandas as pd
import wandb
from torch.utils.tensorboard import SummaryWriter
from trinity.common.config import Config
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry
MONITOR = Registry("monitor")
[docs]
class Monitor(ABC):
"""Monitor"""
[docs]
def __init__(
self,
project: str,
name: str,
role: str,
config: Config = None, # pass the global Config for recording
) -> None:
self.project = project
self.name = name
self.role = role
self.config = config
[docs]
@abstractmethod
def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
"""Log a table"""
[docs]
@abstractmethod
def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""
[docs]
@abstractmethod
def close(self) -> None:
"""Close the monitor"""
def __del__(self) -> None:
self.close()
[docs]
def calculate_metrics(
self, data: dict[str, Union[List[float], float]], prefix: Optional[str] = None
) -> dict[str, float]:
metrics = {}
for key, val in data.items():
if prefix is not None:
key = f"{prefix}/{key}"
if isinstance(val, List):
if len(val) > 1:
metrics[f"{key}/mean"] = np.mean(val)
metrics[f"{key}/max"] = np.amax(val)
metrics[f"{key}/min"] = np.amin(val)
elif len(val) == 1:
metrics[key] = val[0]
else:
metrics[key] = val
return metrics
[docs]
@MONITOR.register_module("tensorboard")
class TensorboardMonitor(Monitor):
[docs]
def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard")
os.makedirs(self.tensorboard_dir, exist_ok=True)
self.logger = SummaryWriter(self.tensorboard_dir)
self.console_logger = get_logger(__name__)
[docs]
def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
pass
[docs]
def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""
for key in data:
self.logger.add_scalar(key, data[key], step)
[docs]
def close(self) -> None:
self.logger.close()
[docs]
@MONITOR.register_module("wandb")
class WandbMonitor(Monitor):
[docs]
def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
self.logger = wandb.init(
project=project,
group=name,
name=f"{name}_{role}",
tags=[role],
config=config,
save_code=False,
)
self.console_logger = get_logger(__name__)
[docs]
def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
experiences_table = wandb.Table(dataframe=experiences_table)
self.log(data={table_name: experiences_table}, step=step)
[docs]
def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""
self.logger.log(data, step=step, commit=commit)
self.console_logger.info(f"Step {step}: {data}")
[docs]
def close(self) -> None:
self.logger.finish()