Source code for memoryscope.core.models.base_model

import inspect
import time
import os
from abc import abstractmethod, ABCMeta
from typing import Any

from memoryscope.core.utils.logger import Logger
from memoryscope.core.utils.registry import Registry
from memoryscope.core.utils.timer import Timer
from memoryscope.enumeration.model_enum import ModelEnum
from memoryscope.scheme.model_response import ModelResponse, ModelResponseGen
from memoryscope.core.memoryscope_context import MemoryscopeContext
from memoryscope.core.memoryscope_context import get_memoryscope_uuid

MODEL_REGISTRY = Registry("models")


[docs] class BaseModel(metaclass=ABCMeta): m_type: ModelEnum | None = None
[docs] def __init__(self, model_name: str, module_name: str, timeout: int = None, max_retries: int = 3, retry_interval: float = 1.0, kwargs_filter: bool = True, raise_exception: bool = True, **kwargs): self.model_name: str = model_name self.module_name: str = module_name self.timeout: int = timeout self.max_retries: int = max_retries self.retry_interval: float = retry_interval self.kwargs_filter: bool = kwargs_filter self.raise_exception: bool = raise_exception self.context: MemoryscopeContext = get_memoryscope_uuid() self.kwargs: dict = kwargs self._model: Any = None self.logger = Logger.get_logger("base_model")
@property def model(self): if self._model is None: if self.module_name not in MODEL_REGISTRY.module_dict: raise RuntimeError(f"method_type={self.module_name} is not supported!") obj_cls = MODEL_REGISTRY[self.module_name] if 'openai' in self.module_name: if os.environ.get('OPENAI_API_KEY', None) is None: raise ValueError("Missing openai api key!") if self.kwargs_filter: allowed_kwargs = list(inspect.signature(obj_cls.__init__).parameters.keys()) kwargs = {key: value for key, value in self.kwargs.items() if key in allowed_kwargs} else: kwargs = self.kwargs self._model = obj_cls(**kwargs) return self._model
[docs] @abstractmethod def before_call(self, model_response: ModelResponse, **kwargs): pass
[docs] @abstractmethod def after_call(self, model_response: ModelResponse, **kwargs) -> ModelResponse | ModelResponseGen: pass
@abstractmethod def _call(self, model_response: ModelResponse, stream: bool = False, **kwargs): pass
[docs] def call(self, stream: bool = False, **kwargs) -> ModelResponse | ModelResponseGen: with Timer(self.__class__.__name__, time_log_type="none") as t: model_response = ModelResponse(m_type=self.m_type) self.before_call(stream=stream, model_response=model_response, **kwargs) for i in range(self.max_retries): if self.raise_exception: self._call(stream=stream, model_response=model_response, **kwargs) else: try: self._call(stream=stream, model_response=model_response, **kwargs) except Exception as e: model_response.status = False model_response.details = e.args if isinstance(model_response, ModelResponse) and not model_response.status: self.logger.warning(f"call model={self.model_name} failed! {t.cost_str} retry_cnt={i} " f"details={model_response.details}", stacklevel=2) time.sleep(i * self.retry_interval) else: return self.after_call(stream=stream, model_response=model_response, **kwargs)
@abstractmethod async def _async_call(self, model_response: ModelResponse, **kwargs) -> ModelResponse: pass
[docs] async def async_call(self, **kwargs) -> ModelResponse: with Timer(self.__class__.__name__, time_log_type="none") as t: model_response = ModelResponse(m_type=self.m_type) self.before_call(model_response=model_response, **kwargs) for i in range(self.max_retries): if self.raise_exception: await self._async_call(model_response=model_response, **kwargs) else: try: await self._async_call(model_response=model_response, **kwargs) except Exception as e: model_response.status = False model_response.details = e.args if not model_response.status: self.logger.warning(f"async_call model={self.model_name} failed! {t.cost_str} retry_cnt={i} " f"details={model_response.details}", stacklevel=2) time.sleep(i * self.retry_interval) else: return self.after_call(model_response=model_response, **kwargs)