"""Monitor"""
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union
import numpy as np
import pandas as pd
try:
import wandb
except ImportError:
wandb = None
try:
import mlflow
except ImportError:
mlflow = None
from torch.utils.tensorboard import SummaryWriter
from trinity.common.config import Config
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry
# Registry of monitor implementations. Concrete monitor classes in this module
# add themselves under a string key via ``@MONITOR.register_module(<key>)``.
MONITOR = Registry("monitor")
[docs]
def gather_metrics(metric_list: List[Dict], prefix: str) -> Dict:
    """Aggregate numeric metrics from a list of dicts into mean/max/min.

    Every dict in ``metric_list`` becomes one row of a DataFrame. Only the
    columns with a numeric dtype are aggregated; all other columns are
    ignored.

    Args:
        metric_list: Metric dictionaries to aggregate.
        prefix: Prefix prepended to every produced key.

    Returns:
        A dict with the keys ``{prefix}/{column}/mean``,
        ``{prefix}/{column}/max`` and ``{prefix}/{column}/min`` for each
        numeric column. An empty dict is returned when there is nothing
        numeric to aggregate.
    """
    if not metric_list:
        return {}
    numeric_df = pd.DataFrame(metric_list).select_dtypes(include=[np.number])
    if numeric_df.empty:
        # Aggregating a frame without any numeric column fails inside pandas,
        # so report "no metrics" instead of crashing the caller.
        return {}
    stats_df = numeric_df.agg(["mean", "max", "min"])
    metric = {}
    for col in stats_df.columns:
        for stat in ("mean", "max", "min"):
            metric[f"{prefix}/{col}/{stat}"] = stats_df.loc[stat, col]
    return metric
[docs]
class Monitor(ABC):
    """Abstract base class for experiment monitors.

    Concrete monitors implement :meth:`log_table`, :meth:`log` and
    :meth:`close`. The base class stores the identifying fields and offers
    :meth:`calculate_metrics` as a helper for reducing raw values.
    """

    def __init__(
        self,
        project: str,
        name: str,
        role: str,
        config: Optional["Config"] = None,  # pass the global Config for recording
    ) -> None:
        """Store the identifying fields of this monitor.

        Args:
            project: Project name.
            name: Run/experiment name.
            role: Role of the component that owns this monitor.
            config: The global configuration object, kept for recording.
        """
        self.project = project
        self.name = name
        self.role = role
        self.config = config

    @abstractmethod
    def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
        """Log a table"""

    @abstractmethod
    def log(self, data: dict, step: int, commit: bool = False) -> None:
        """Log metrics."""

    @abstractmethod
    def close(self) -> None:
        """Close the monitor"""

    def __del__(self) -> None:
        # The finalizer also runs for instances whose ``__init__`` raised
        # before the backend handle was created. ``close()`` then hits a
        # missing attribute; there is nothing to release in that case.
        try:
            self.close()
        except AttributeError:
            pass

    def calculate_metrics(
        self, data: dict[str, Union[List[float], float]], prefix: Optional[str] = None
    ) -> dict[str, float]:
        """Reduce raw metric values to scalars.

        Args:
            data: Mapping from metric name to either a scalar or a list of
                values.
            prefix: Optional prefix; when given, keys become
                ``{prefix}/{key}``.

        Returns:
            A dict in which a list with more than one element is expanded
            into ``{key}/mean``, ``{key}/max`` and ``{key}/min``, a
            single-element list is unwrapped to ``{key}``, and a non-list
            value is passed through as ``{key}``. Empty lists produce no
            entry.
        """
        metrics = {}
        for key, val in data.items():
            if prefix is not None:
                key = f"{prefix}/{key}"
            if not isinstance(val, list):
                metrics[key] = val
                continue
            if len(val) > 1:
                metrics[f"{key}/mean"] = np.mean(val)
                metrics[f"{key}/max"] = np.amax(val)
                metrics[f"{key}/min"] = np.amin(val)
            elif len(val) == 1:
                metrics[key] = val[0]
            # An empty list carries no value and is skipped.
        return metrics

    @classmethod
    def default_args(cls) -> Dict:
        """Return default arguments for the monitor."""
        return {}
[docs]
@MONITOR.register_module("tensorboard")
class TensorboardMonitor(Monitor):
    """Monitor that writes scalar metrics with a TensorBoard ``SummaryWriter``.

    Event files are written to
    ``<config.monitor.cache_dir>/tensorboard/<role>``.
    """

    def __init__(
        self, project: str, group: str, name: str, role: str, config: Optional[Config] = None
    ) -> None:
        """Create the log directory and open the summary writer.

        Args:
            project: Project name.
            group: Group name; accepted but not used by this monitor.
            name: Run/experiment name.
            role: Role of the owning component; also the log sub-directory.
            config: Global configuration. ``config.monitor.cache_dir`` is
                read here, so a real config object is required.
        """
        # Populate project/name/role/config declared by the base class.
        super().__init__(project, name, role, config)
        self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard", role)
        os.makedirs(self.tensorboard_dir, exist_ok=True)
        self.logger = SummaryWriter(self.tensorboard_dir)
        self.console_logger = get_logger(__name__)

    def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
        """Do nothing; tables are not logged by this monitor."""
        pass

    def log(self, data: dict, step: int, commit: bool = False) -> None:
        """Log metrics.

        Each entry of ``data`` is written as a scalar at ``step`` and the
        whole dict is echoed to the console logger. ``commit`` is unused.
        """
        for key, value in data.items():
            self.logger.add_scalar(key, value, step)
        self.console_logger.info(f"Step {step}: {data}")

    def close(self) -> None:
        """Close the summary writer if it was created."""
        # ``Monitor.__del__`` calls close() even when __init__ failed before
        # ``self.logger`` was assigned.
        writer = getattr(self, "logger", None)
        if writer is not None:
            writer.close()
[docs]
@MONITOR.register_module("wandb")
class WandbMonitor(Monitor):
    """Monitor with Weights & Biases.

    Args:
        base_url (`Optional[str]`): The base URL of the W&B server. If not provided, use the environment variable `WANDB_BASE_URL`.
        api_key (`Optional[str]`): The API key for W&B. If not provided, use the environment variable `WANDB_API_KEY`.
    """

    def __init__(
        self, project: str, group: str, name: str, role: str, config: Optional[Config] = None
    ) -> None:
        """Start a W&B run named ``{name}_{role}``.

        Args:
            project: W&B project name.
            group: W&B group; falls back to ``name`` when empty.
            name: Run/experiment name.
            role: Role of the owning component; used as run tag and suffix.
            config: Global configuration. ``config.monitor.monitor_args`` is
                read here, so a real config object is required.
        """
        assert wandb is not None, "wandb is not installed. Please install it to use WandbMonitor."
        # Populate project/name/role/config declared by the base class.
        super().__init__(project, name, role, config)
        if not group:
            group = name
        monitor_args = config.monitor.monitor_args or {}
        # Explicit monitor arguments override the process environment.
        if base_url := monitor_args.get("base_url"):
            os.environ["WANDB_BASE_URL"] = base_url
        if api_key := monitor_args.get("api_key"):
            os.environ["WANDB_API_KEY"] = api_key
        self.logger = wandb.init(
            project=project,
            group=group,
            name=f"{name}_{role}",
            tags=[role],
            config=config,
            save_code=False,
        )
        self.console_logger = get_logger(__name__)

    def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
        """Wrap the DataFrame in a ``wandb.Table`` and log it under ``table_name``."""
        wandb_table = wandb.Table(dataframe=experiences_table)
        self.log(data={table_name: wandb_table}, step=step)

    def log(self, data: dict, step: int, commit: bool = False) -> None:
        """Log metrics.

        ``data`` is forwarded to the W&B run together with ``step`` and
        ``commit`` and echoed to the console logger.
        """
        self.logger.log(data, step=step, commit=commit)
        self.console_logger.info(f"Step {step}: {data}")

    def close(self) -> None:
        """Finish the W&B run if it was started."""
        # ``Monitor.__del__`` calls close() even when __init__ failed before
        # ``self.logger`` was assigned.
        run = getattr(self, "logger", None)
        if run is not None:
            run.finish()

    @classmethod
    def default_args(cls) -> Dict:
        """Return default arguments for the monitor."""
        return {
            "base_url": None,
            "api_key": None,
        }
[docs]
@MONITOR.register_module("mlflow")
class MlflowMonitor(Monitor):
    """Monitor with MLflow.

    Args:
        uri (`Optional[str]`): The tracking server URI. If not provided, the default is `http://localhost:5000`.
        username (`Optional[str]`): The username to login. If not provided, the default is `None`.
        password (`Optional[str]`): The password to login. If not provided, the default is `None`.
    """

    def __init__(
        self, project: str, group: str, name: str, role: str, config: Optional[Config] = None
    ) -> None:
        """Configure the tracking server and start a run named ``{name}_{role}``.

        Args:
            project: Used as the MLflow experiment name.
            group: Stored as the ``group`` tag of the run.
            name: Run/experiment name.
            role: Role of the owning component; run tag and name suffix.
            config: Global configuration. ``config.monitor.monitor_args`` and
                ``config.flatten()`` are used here, so a real config object
                is required.
        """
        assert (
            mlflow is not None
        ), "mlflow is not installed. Please install it to use MlflowMonitor."
        # Populate project/name/role/config declared by the base class.
        super().__init__(project, name, role, config)
        monitor_args = config.monitor.monitor_args or {}
        # Explicit monitor arguments override the process environment.
        if username := monitor_args.get("username"):
            os.environ["MLFLOW_TRACKING_USERNAME"] = username
        if password := monitor_args.get("password"):
            os.environ["MLFLOW_TRACKING_PASSWORD"] = password
        # Read from the defaulted local dict: ``config.monitor.monitor_args``
        # may be None, in which case calling ``.get`` on it would fail.
        mlflow.set_tracking_uri(monitor_args.get("uri", "http://localhost:5000"))
        mlflow.set_experiment(project)
        mlflow.start_run(
            run_name=f"{name}_{role}",
            tags={
                "group": group,
                "role": role,
            },
        )
        mlflow.log_params(config.flatten())
        self.console_logger = get_logger(__name__)

    def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
        """Do nothing; tables are not logged by this monitor."""
        pass

    def log(self, data: dict, step: int, commit: bool = False) -> None:
        """Log metrics.

        ``data`` is sent to MLflow at ``step`` and echoed to the console
        logger. ``commit`` is unused.
        """
        mlflow.log_metrics(metrics=data, step=step)
        self.console_logger.info(f"Step {step}: {data}")

    def close(self) -> None:
        """End the MLflow run."""
        # ``Monitor.__del__`` calls close() even when the constructor bailed
        # out because mlflow is not installed.
        if mlflow is not None:
            mlflow.end_run()

    @classmethod
    def default_args(cls) -> Dict:
        """Return default arguments for the monitor."""
        return {
            "uri": "http://localhost:5000",
            "username": None,
            "password": None,
        }