Source code for trinity.manager.config_registry.config_registry

from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set

import streamlit as st

from trinity.utils.registry import Registry


class ConfigRegistry(Registry):
    """A registry for configuration setter functions, their defaults and checks.

    Each configuration is registered under a name derived from its setter
    function (``set_<name>``). Alongside the setter, the registry tracks:

    * a default value per configuration name (``default_config``),
    * an optional ``visible`` predicate deciding whether the setter is
      returned by :meth:`get`,
    * an optional check function stored under the name ``check_<name>``.
    """

    def __init__(self, name: str):
        """Create an empty configuration registry.

        Args:
            name (str): Name forwarded to the base ``Registry``.
        """
        super().__init__(name)
        # Default value for every registered config name (including the
        # extra names supplied through ``other_configs``).
        self._default_config: Dict[str, Any] = {}
        # Visibility predicates, keyed by config name. Only configs that
        # were registered with a ``visible`` callable appear here.
        self._config_visibles: Dict[str, Callable[[], bool]] = {}
        # Field names that are not yet configured; handed to check functions.
        self.unfinished_fields: set = set()

    def set_unfinished_fields(self, unfinished_fields: set):
        """Set the unfinished fields to track incomplete configurations.

        The given set is stored by reference (not copied), so check functions
        invoked from :meth:`get_configs` receive this exact object.

        Args:
            unfinished_fields (set): Set of field names that are not yet
                configured.
        """
        self.unfinished_fields = unfinished_fields

    @property
    def default_config(self) -> dict:
        """Get the dictionary of default configuration values.

        Returns:
            dict: The internal mapping itself (not a copy).
        """
        return self._default_config

    def get(self, config_name: str):
        """Retrieve a configuration function if its visibility predicate passes.

        Args:
            config_name (str): Name of the configuration to retrieve.

        Returns:
            ``None`` when a ``visible`` predicate is registered for
            ``config_name`` and it returns a falsy value; otherwise whatever
            the base ``Registry.get`` returns for ``config_name``.
        """
        visible = self._config_visibles.get(config_name)
        if visible is not None and not visible():
            return None
        return super().get(config_name)

    def get_check_func(self, config_name: str):
        """Get the check function associated with a configuration.

        Args:
            config_name (str): Name of the configuration.

        Returns:
            The result of the base ``Registry.get`` for ``check_<config_name>``
            (presumably ``None`` when no check is registered -- the caller in
            :meth:`get_configs` guards against ``None``).
        """
        return super().get(f"check_{config_name}")

    def get_configs(self, *config_names: str, columns_spec: Optional[List[int]] = None):
        """Retrieve and display multiple configurations in Streamlit columns.

        Every visible configuration function is called inside its own
        Streamlit column; afterwards the check function of each visible
        configuration (if any) is called with
        ``unfinished_fields=self.unfinished_fields``.

        Args:
            *config_names (str): Names of configurations to retrieve.
            columns_spec (List[int], optional): Spec passed to ``st.columns``.
                When omitted, one column per visible configuration is
                created.
        """
        config_pairs = []
        for config_name in config_names:
            config_func = self.get(config_name)
            if config_func is not None:
                config_pairs.append((config_name, config_func))

        # Nothing visible: do not create any columns.
        if not config_pairs:
            return

        if columns_spec is None:
            columns_spec = len(config_pairs)
        columns = st.columns(columns_spec)

        # ``zip`` stops at the shorter sequence, so with a caller-supplied
        # ``columns_spec`` shorter than the visible configs, the surplus
        # configs are not rendered (their checks below still run).
        for col, (_, config_func) in zip(columns, config_pairs):
            with col:
                config_func()

        # Checks run after all widgets have been rendered.
        for config_name, _ in config_pairs:
            check_func = self.get_check_func(config_name)
            if check_func is not None:
                check_func(unfinished_fields=self.unfinished_fields)

    def _register_config(
        self,
        config_name: str,
        config_func: Callable[..., None],
        default_value: Optional[Any] = None,
        visible: Optional[Callable[[], bool]] = None,
        other_configs: Optional[Dict[str, Any]] = None,
    ):
        """Internal method to register a configuration and its function.

        All names are validated before any state is modified, so a name
        collision does not leave partially registered defaults behind.

        Args:
            config_name (str): Name of the configuration.
            config_func (Callable): Function to set the configuration.
            default_value (Any, optional): Default value for the configuration.
            visible (Callable, optional): Predicate deciding whether the
                config should be visible/applicable.
            other_configs (Dict[str, Any], optional): Additional
                ``name -> default value`` entries to add to the defaults.

        Raises:
            AssertionError: If ``config_name`` or any key of ``other_configs``
                already has a default registered.
        """
        extra_defaults = other_configs if other_configs is not None else {}

        # Validate first, mutate afterwards.
        assert config_name not in self._default_config, f"{config_name} already exists."
        for name in extra_defaults:
            # ``name == config_name`` is also a collision: the primary name
            # is about to be registered in this same call.
            assert (
                name not in self._default_config and name != config_name
            ), f"{name} already exists."

        self._default_config[config_name] = default_value
        if visible is not None:
            self._config_visibles[config_name] = visible
        for name, value in extra_defaults.items():
            self._default_config[name] = value
        super()._register_module(module_name=config_name, module_cls=config_func)

    def register_config(
        self,
        default_value: Optional[Any] = None,
        config_func: Optional[Callable[..., None]] = None,
        visible: Optional[Callable[[], bool]] = None,
        other_configs: Optional[Dict[str, Any]] = None,
    ):
        """Decorator to register a configuration function.

        The function name must start with 'set_', and the part after 'set_'
        becomes the config name.

        Note:
            This function will automatically pass `key=config_name` as an
            argument to the registered configuration function. Ensure your
            function accepts this keyword argument. The object that is
            registered -- and returned to the decoration site -- is the
            ``functools.partial`` wrapper, not the original function.

        Args:
            default_value (Any, optional): Default value for the configuration.
            config_func (Callable, optional): The configuration function to
                register.
            visible (Callable, optional): Predicate deciding whether the
                config should be visible.
            other_configs (Dict[str, Any], optional): Additional
                configurations to register.

        Returns:
            A decorator function if config_func is None, else the registered
            config function.

        Raises:
            AssertionError: If the function name does not start with ``set_``
                or a name is already registered.
        """

        def _register(config_func: Callable[..., None]):
            func_name = config_func.__name__
            prefix = "set_"
            assert func_name.startswith(
                prefix
            ), f"Config function name should start with `{prefix}`, got {func_name}"
            config_name = func_name[len(prefix) :]
            # Pre-bind the widget key so callers invoke the setter without
            # arguments.
            bound_func = partial(config_func, key=config_name)
            self._register_config(
                config_name=config_name,
                config_func=bound_func,
                default_value=default_value,
                visible=visible,
                other_configs=other_configs,
            )
            return bound_func

        # With no function supplied, act as a decorator factory.
        if config_func is not None:
            return _register(config_func)
        return _register

    def _register_check(self, config_name: str, check_func: Callable[..., None]):
        """Internal method to register a check function for a configuration.

        Args:
            config_name (str): Name of the configuration to check.
            check_func (Callable): Function to check the configuration.

        Raises:
            AssertionError: If ``config_name`` has no registered default, i.e.
                the configuration was not registered beforehand.
        """
        assert config_name in self._default_config, f"`{config_name}` is not registered."
        super()._register_module(module_name=f"check_{config_name}", module_cls=check_func)

    def register_check(self, check_func: Optional[Callable[..., None]] = None):
        """Decorator to register a check function for a configuration.

        The function name must start with 'check_', and the part after
        'check_' should match a config name.

        Note:
            `key=config_name` is pre-bound at registration time, and
            `unfinished_fields=self.unfinished_fields` is passed when the check
            is invoked from :meth:`get_configs`. Ensure your function accepts
            these keyword arguments. The object that is registered -- and
            returned to the decoration site -- is the ``functools.partial``
            wrapper, not the original function.

        Args:
            check_func (Callable, optional): The check function to register.

        Returns:
            A decorator function if check_func is None, else the registered
            check function.

        Raises:
            AssertionError: If the function name does not start with
                ``check_`` or the matching config is not registered.
        """

        def _register(check_func: Callable[..., None]):
            func_name = check_func.__name__
            prefix = "check_"
            assert func_name.startswith(
                prefix
            ), f"Check function name must start with `{prefix}`, got {func_name}"
            config_name = func_name[len(prefix) :]
            bound_check = partial(check_func, key=config_name)
            self._register_check(config_name, bound_check)
            return bound_check

        # With no function supplied, act as a decorator factory.
        if check_func is not None:
            return _register(check_func)
        return _register
# Module-level registry instance shared by all configuration generators.
CONFIG_GENERATORS = ConfigRegistry(name="config_generators")