# Source code for memoryscope.core.config.config_manager

import json
import os
from dataclasses import fields
from datetime import datetime
from pathlib import Path
from typing import Optional, Literal

import yaml

from memoryscope.core.config.arguments import Arguments
from memoryscope.core.utils.logger import Logger


class ConfigManager(object):
    """Assemble the MemoryScope configuration dict from a config file plus overrides.

    Sources, from lowest to highest priority:

    1. The YAML/JSON file at ``config_path`` -- or, when no path is given, the file
       ``demo_config_name`` located in this module's directory.
    2. Keyword arguments whose names match a field of ``Arguments`` (other keyword
       names are kept in ``self.kwargs`` but not used for overrides).
    3. Environment variables whose names match a field of ``Arguments``; their values
       are forwarded to ``Arguments`` as raw strings.
    4. An explicit ``arguments`` object; when it is given, sources 2 and 3 are not used.

    Only argument values that are not ``None`` override what the file contains.
    """

    def __init__(self,
                 config_path: Optional[str] = None,
                 arguments: Optional[Arguments] = None,
                 demo_config_name: str = "demo_config_zh.yaml",
                 **kwargs):
        """Load the config file, apply the overrides and log the resulting config.

        Args:
            config_path: path to a ``.yaml``/``.yml``/``.json`` config file.
            arguments: pre-built overrides; takes precedence over kwargs and env vars.
            demo_config_name: file name, relative to this module's directory, that is
                loaded when ``config_path`` is not given.
            **kwargs: overrides keyed by ``Arguments`` field names.

        Raises:
            RuntimeError: if none of ``config_path``, ``kwargs`` or ``arguments`` is given.
            ValueError: if the config file has an unsupported extension.
        """
        self.config: dict = {}
        self.kwargs = kwargs  # the caller's raw kwargs, before filtering
        self.logger = Logger.get_logger("memoryscope")

        if not (config_path or kwargs or arguments):
            raise RuntimeError("can not init config manager without kwargs or --config_path!")

        if not config_path:
            config_path = str(Path(__file__).parent / demo_config_name)
        self.read_config(config_path)

        # An explicitly passed `arguments` object has the highest priority and is used as-is.
        if not arguments:
            arguments = Arguments(**self._collect_argument_values(kwargs))

        self.update_config_by_arguments(arguments)
        self.logger.info("\n" + self.dump_config())

    @staticmethod
    def _collect_argument_values(kwargs: dict) -> dict:
        """Build the ``Arguments(**...)`` init values from ``kwargs`` and the environment.

        Every ``Arguments`` field gets an entry: the matching keyword argument, else ``None``.
        A set environment variable with the same name overrides either (as a string).
        """
        field_names = [x.name for x in fields(Arguments)]
        values = {name: kwargs.get(name) for name in field_names}
        values.update({name: os.environ[name] for name in field_names if name in os.environ})
        return values

    def read_config(self, config_path: str):
        """Replace ``self.config`` with the parsed contents of ``config_path``.

        Paths ending in ``.yaml``/``.yml`` are parsed as YAML and ``.json`` as JSON; the
        extension check is case-insensitive. Files are decoded as UTF-8 rather than with
        the platform's locale encoding.

        Raises:
            ValueError: for any other extension. Raising here, instead of keeping an empty
                config, surfaces the problem at its source rather than as a later KeyError
                in ``update_config_by_arguments``.
        """
        lowered_path = config_path.lower()
        if lowered_path.endswith((".yaml", ".yml")):
            with open(config_path, encoding="utf-8") as f:
                # NOTE(review): FullLoader resolves more than plain YAML data; consider
                # yaml.safe_load if config files may come from untrusted sources.
                self.config = yaml.load(f, yaml.FullLoader)
        elif lowered_path.endswith(".json"):
            with open(config_path, encoding="utf-8") as f:
                self.config = json.load(f)
        else:
            raise ValueError(
                f"unsupported config file type (expected .yaml, .yml or .json): {config_path}")

    @staticmethod
    def update_ignore_none(config, new_config_dict):
        """Update ``config`` in place with the entries of ``new_config_dict`` whose value is not None."""
        config.update({k: v for k, v in new_config_dict.items() if v is not None})

    @staticmethod
    def update_global_by_arguments(config: dict, arguments: Arguments):
        """Override the global settings in ``config`` with the non-None values from ``arguments``."""
        ConfigManager.update_ignore_none(config, {
            "language": arguments.language,
            "thread_pool_max_workers": arguments.thread_pool_max_workers,
            "enable_ranker": arguments.enable_ranker,
            "enable_today_contra_repeat": arguments.enable_today_contra_repeat,
            "enable_long_contra_repeat": arguments.enable_long_contra_repeat,
            "output_memory_max_count": arguments.output_memory_max_count,
        })

    @staticmethod
    def update_memory_chat_by_arguments(config: dict, arguments: Arguments):
        """Swap the memory-chat class name and set its ``stream`` flag.

        Does nothing unless ``arguments.memory_chat_class`` is set. When it is set, the last
        dotted component of ``config["class"]`` is replaced by it. ``stream`` is set to
        ``arguments.chat_stream`` when that is given, otherwise to ``True`` only for
        ``"cli_memory_chat"``.
        """
        if arguments.memory_chat_class is None:
            # NOTE(review): chat_stream is ignored unless memory_chat_class is also given;
            # confirm this is intended.
            return

        module_parts = config["class"].split(".")[:-1]
        stream = arguments.chat_stream
        if stream is None:
            stream = arguments.memory_chat_class in ("cli_memory_chat",)

        config.update({
            "class": ".".join(module_parts + [arguments.memory_chat_class]),
            "stream": stream,
        })

    @staticmethod
    def update_memory_service_by_arguments(config: dict, arguments: Arguments):
        """Override speaker names and memory-operation interval times in ``config``.

        ``config["memory_operations"]`` is only accessed when an interval override is given.
        """
        ConfigManager.update_ignore_none(config, {
            "human_name": arguments.human_name,
            "assistant_name": arguments.assistant_name,
        })
        if arguments.consolidate_memory_interval_time is not None:
            config["memory_operations"]["consolidate_memory"]["interval_time"] = \
                arguments.consolidate_memory_interval_time
        if arguments.reflect_and_reconsolidate_interval_time is not None:
            config["memory_operations"]["reflect_and_reconsolidate"]["interval_time"] = \
                arguments.reflect_and_reconsolidate_interval_time

    @staticmethod
    def update_worker_by_arguments(config: dict, arguments: Arguments):
        """Merge ``arguments.worker_params`` into the matching worker configs.

        ``worker_params`` maps worker name to a dict of settings. Names not present in
        ``config`` are skipped.
        """
        if arguments.worker_params is None:
            return
        for worker_name, kv_dict in arguments.worker_params.items():
            if worker_name in config:
                config[worker_name].update(kv_dict)

    @staticmethod
    def update_model_by_arguments(config: dict, arguments: Arguments):
        """Override the generation, embedding and rank model configs in ``config``.

        For each model section, ``*_backend`` becomes ``module_name`` and ``*_model`` becomes
        ``model_name``. Then the ``*_params`` dict (when it is a dict) is merged in. In every
        case, ``None`` values are skipped.
        """
        model_overrides = (
            ("generation_model", arguments.generation_backend, arguments.generation_model,
             arguments.generation_params),
            ("embedding_model", arguments.embedding_backend, arguments.embedding_model,
             arguments.embedding_params),
            ("rank_model", arguments.rank_backend, arguments.rank_model,
             arguments.rank_params),
        )
        for section, backend, model_name, params in model_overrides:
            section_config = config[section]
            ConfigManager.update_ignore_none(section_config, {
                "module_name": backend,
                "model_name": model_name,
            })
            if isinstance(params, dict):
                ConfigManager.update_ignore_none(section_config, params)

    @staticmethod
    def update_memory_store_by_arguments(config: dict, arguments: Arguments):
        """Override the memory-store settings (ES index name, ES URL, retrieve mode) in ``config``."""
        ConfigManager.update_ignore_none(config, {
            "index_name": arguments.es_index_name,
            "es_url": arguments.es_url,
            "retrieve_mode": arguments.retrieve_mode,
        })

    def update_config_by_arguments(self, arguments: Arguments):
        """Apply ``arguments`` to every section of ``self.config`` in place.

        For the ``memory_chat`` and ``memory_service`` nodes, only the first configured
        entry is updated.
        """
        self.update_global_by_arguments(self.config["global"], arguments)

        memory_chat_config = list(self.config["memory_chat"].values())[0]
        self.update_memory_chat_by_arguments(memory_chat_config, arguments)

        memory_service_config = list(self.config["memory_service"].values())[0]
        self.update_memory_service_by_arguments(memory_service_config, arguments)

        self.update_worker_by_arguments(self.config["worker"], arguments)
        self.update_model_by_arguments(self.config["model"], arguments)
        self.update_memory_store_by_arguments(self.config["memory_store"], arguments)

    def add_node_object(self, node: str, name: str, config: dict):
        """Set ``self.config[node][name] = config``; ``node`` must already exist."""
        self.config[node][name] = config

    def pop_node_object(self, node: str, name: str):
        """Remove and return ``self.config[node][name]``, or return None if it is absent."""
        return self.config[node].pop(name, None)

    def clear_node_all(self, node: str):
        """Remove every entry under ``self.config[node]``."""
        self.config[node].clear()

    def dump_config(self, file_type: Literal["json", "yaml"] = "yaml", file_path: Optional[str] = None) -> str:
        """Serialize ``self.config`` as JSON or YAML, optionally writing it to ``file_path``.

        Non-ASCII text is kept as-is, so the file is written as UTF-8 explicitly instead
        of with the locale's default encoding.

        Returns:
            The serialized text.

        Raises:
            NotImplementedError: for a ``file_type`` other than "json" or "yaml".
        """
        if file_type == "json":
            content = json.dumps(self.config, indent=2, ensure_ascii=False)
        elif file_type == "yaml":
            content = yaml.dump(self.config, indent=2, allow_unicode=True)
        else:
            raise NotImplementedError

        if file_path:
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
        return content