# Copyright (c) Alibaba, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------
# Most of the code here has been modified from:
# https://github.com/modelscope/modelscope/blob/master/modelscope/utils/registry.py
# --------------------------------------------------------
from loguru import logger
class Registry:
    """Name-keyed store of modules, grouped under a registry (repo) name.

    Modules are usually classes. They are added with
    :meth:`register_module`, either through a direct call or by using it
    as a class decorator, and are looked up again with :meth:`get`.
    """

    def __init__(self, name: str):
        """
        Initialization method.

        :param name: a registry repo name
        """
        self._name = name
        # Maps each registered module name to the registered object.
        self._modules = {}

    @property
    def name(self):
        """
        Get name of current registry.

        :return: name of current registry.
        """
        return self._name

    @property
    def modules(self):
        """
        Get all modules in current registry.

        :return: the dict storing modules in current registry. This is
            the registry's own dict, not a copy.
        """
        return self._modules

    def list(self):
        """Log the name of every module in current registry."""
        for registered_name in self._modules:
            logger.info(f'{self._name}\t{registered_name}')

    def get(self, module_key):
        """
        Get module named module_key from current registry. If not found,
        return None.

        :param module_key: specified module name
        :return: module named module_key, or None if it is not registered
        """
        return self._modules.get(module_key, None)

    def _register_module(self, module_name=None, module_cls=None, force=False):
        """
        Register module to registry.

        On success the object is stored under ``module_name`` and its
        ``_name`` attribute is set to that name.

        :param module_name: module name. When None, ``module_cls.__name__``
            is used instead.
        :param module_cls: module class object
        :param force: Whether to override an existing class with the
            same name. Default: False.
        :raises KeyError: if ``module_name`` is already registered and
            ``force`` is False.
        """
        if module_name is None:
            module_name = module_cls.__name__

        if module_name in self._modules and not force:
            raise KeyError(
                f'{module_name} is already registered in {self._name}')

        # Tag the object before storing it: if it rejects the attribute
        # (e.g. a builtin type), the exception propagates while the
        # registry is still unmodified.
        module_cls._name = module_name
        self._modules[module_name] = module_cls

    def register_module(self,
                        module_name: str = None,
                        module_cls: type = None,
                        force=False):
        """
        Register module class object to registry with the specified module
        name.

        Can be called directly with ``module_cls``, or without it to obtain
        a decorator that registers the decorated class.

        :param module_name: module name. When None, the class's
            ``__name__`` is used.
        :param module_cls: module class object. When None, a decorator is
            returned instead of registering immediately.
        :param force: Whether to override an existing class with
            the same name. Default: False.
        :return: ``module_cls`` when it is given, otherwise a decorator
            that registers the class it receives and returns it unchanged.
        :raises TypeError: if ``module_name`` is neither None nor a str.
        :raises KeyError: if the name is already registered and ``force``
            is False.

        Example:
            >>> registry = Registry('formatters')
            >>> @registry.register_module()
            ... class TextFormatter:
            ...     pass
            >>> class TextFormatter2:
            ...     pass
            >>> _ = registry.register_module(module_name='text_formatter2',
            ...                              module_cls=TextFormatter2)
        """
        # Validate up front so a bad name fails at decoration time too,
        # not only once the returned decorator is applied.
        if not (module_name is None or isinstance(module_name, str)):
            raise TypeError(f'module_name must be either of None, str, '
                            f'got {type(module_name)}')

        if module_cls is not None:
            self._register_module(module_name=module_name,
                                  module_cls=module_cls,
                                  force=force)
            return module_cls

        # if module_cls is None, should return a decorator function
        def _register(module_cls):
            """
            Register module class object to registry.

            :param module_cls: module class object
            :return: module class object.
            """
            self._register_module(module_name=module_name,
                                  module_cls=module_cls,
                                  force=force)
            return module_cls

        return _register