Source code for trinity.utils.monitor

"""Monitor"""

import os
from abc import ABC, abstractmethod
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import wandb
from torch.utils.tensorboard import SummaryWriter

from trinity.common.config import Config
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry

# Registry of monitor backends; the concrete Monitor subclasses in this module
# add themselves to it by name through `@MONITOR.register_module(...)`.
MONITOR = Registry("monitor")


class Monitor(ABC):
    """Abstract base class for monitors.

    Subclasses provide the backend-specific implementations of
    :meth:`log_table`, :meth:`log` and :meth:`close`.
    """

    def __init__(
        self,
        project: str,
        name: str,
        role: str,
        config: Optional["Config"] = None,  # pass the global Config for recording
    ) -> None:
        """Store the identifying information of this monitor.

        Args:
            project: Project name.
            name: Name of the run.
            role: Role label of this monitor instance.
            config: The global Config to record; may be ``None``.
        """
        self.project = project
        self.name = name
        self.role = role
        self.config = config

    @abstractmethod
    def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
        """Log a table"""

    @abstractmethod
    def log(self, data: dict, step: int, commit: bool = False) -> None:
        """Log metrics."""

    @abstractmethod
    def close(self) -> None:
        """Close the monitor"""

    def __del__(self) -> None:
        """Close the monitor when the object is finalized (best effort)."""
        # The finalizer also runs for instances whose ``__init__`` raised
        # before the backend was created, and during interpreter shutdown.
        # A failing ``close()`` cannot be handled by anyone at that point,
        # so it is deliberately suppressed instead of leaking out of __del__.
        try:
            self.close()
        except Exception:
            pass

    def calculate_metrics(
        self, data: dict[str, Union[List[float], float]], prefix: Optional[str] = None
    ) -> dict[str, float]:
        """Reduce raw values to scalar metrics.

        Args:
            data: Mapping from metric name to either a single value or a
                list of values.
            prefix: If not ``None``, every key is emitted as ``"{prefix}/{key}"``.

        Returns:
            A dict of metrics. A list with more than one element produces
            ``"{key}/mean"``, ``"{key}/max"`` and ``"{key}/min"``; a list with
            exactly one element produces ``key`` mapped to that element; an
            empty list produces no entry. Any non-list value is passed
            through unchanged under ``key``.
        """
        metrics = {}
        for key, val in data.items():
            if prefix is not None:
                key = f"{prefix}/{key}"
            if not isinstance(val, list):
                metrics[key] = val
            elif len(val) > 1:
                metrics[f"{key}/mean"] = np.mean(val)
                metrics[f"{key}/max"] = np.amax(val)
                metrics[f"{key}/min"] = np.amin(val)
            elif len(val) == 1:
                metrics[key] = val[0]
            # An empty list has nothing to report and is skipped.
        return metrics
@MONITOR.register_module("tensorboard")
class TensorboardMonitor(Monitor):
    """Monitor that writes scalar metrics through a TensorBoard ``SummaryWriter``."""

    def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
        """Create the TensorBoard writer.

        Args:
            project: Project name.
            name: Name of the run.
            role: Role label of this monitor instance.
            config: The global Config. Although it defaults to ``None``, it is
                required here: the output directory is derived from
                ``config.monitor.cache_dir``.
        """
        # Set project/name/role/config on the instance, as the base class defines.
        super().__init__(project, name, role, config)
        self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard")
        os.makedirs(self.tensorboard_dir, exist_ok=True)
        self.logger = SummaryWriter(self.tensorboard_dir)
        self.console_logger = get_logger(__name__)

    def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
        """Do nothing; tables are not written by this monitor."""
        pass

    def log(self, data: dict, step: int, commit: bool = False) -> None:
        """Log metrics.

        Args:
            data: Mapping from scalar tag to value; each entry is written
                with ``add_scalar``.
            step: Global step the values belong to.
            commit: Accepted for interface compatibility; not used here.
        """
        for key, value in data.items():
            self.logger.add_scalar(key, value, step)
        # Mirror the metrics on the console, same message as WandbMonitor.log.
        self.console_logger.info(f"Step {step}: {data}")

    def close(self) -> None:
        """Close the underlying ``SummaryWriter``."""
        self.logger.close()
@MONITOR.register_module("wandb")
class WandbMonitor(Monitor):
    """Monitor that logs metrics and tables to a Weights & Biases run."""

    def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
        """Start the wandb run.

        The run is created in ``project``, grouped under ``name``, named
        ``"{name}_{role}"`` and tagged with ``role``.

        Args:
            project: wandb project name.
            name: Name of the run group.
            role: Role label of this monitor instance.
            config: The global Config, passed to ``wandb.init`` as the run config.
        """
        # Set project/name/role/config on the instance, as the base class defines.
        super().__init__(project, name, role, config)
        self.logger = wandb.init(
            project=project,
            group=name,
            name=f"{name}_{role}",
            tags=[role],
            config=config,
            save_code=False,
        )
        self.console_logger = get_logger(__name__)

    def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
        """Log a DataFrame as a ``wandb.Table`` under ``table_name``."""
        table = wandb.Table(dataframe=experiences_table)
        self.log(data={table_name: table}, step=step)

    def log(self, data: dict, step: int, commit: bool = False) -> None:
        """Log metrics.

        Args:
            data: Mapping passed to the wandb run's ``log``.
            step: Step the values belong to.
            commit: Forwarded to the wandb run's ``log``.
        """
        self.logger.log(data, step=step, commit=commit)
        self.console_logger.info(f"Step {step}: {data}")

    def close(self) -> None:
        """Finish the wandb run."""
        self.logger.finish()