Source code for data_juicer.utils.model_utils

import fnmatch
import inspect
import io
import os
from contextlib import redirect_stderr
from functools import partial
from pickle import UnpicklingError
from typing import Optional, Union

import httpx
import multiprocess as mp
import wget
from loguru import logger

from data_juicer import cuda_device_count
from data_juicer.utils.common_utils import nested_access
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.nltk_utils import (
    ensure_nltk_resource,
    patch_nltk_pickle_security,
)

from .cache_utils import DATA_JUICER_MODELS_CACHE as DJMC

torch = LazyLoader("torch")
transformers = LazyLoader("transformers")
nn = LazyLoader("torch.nn")
fasttext = LazyLoader("fasttext", "fasttext-wheel")
sentencepiece = LazyLoader("sentencepiece")
kenlm = LazyLoader("kenlm")
nltk = LazyLoader("nltk")
aes_pred = LazyLoader("aesthetics_predictor", "simple-aesthetics-predictor")
vllm = LazyLoader("vllm")
diffusers = LazyLoader("diffusers")
ram = LazyLoader("ram", "git+https://github.com/xinyu1205/recognize-anything.git")
cv2 = LazyLoader("cv2", "opencv-python")
openai = LazyLoader("openai")
ultralytics = LazyLoader("ultralytics")
tiktoken = LazyLoader("tiktoken")
dashscope = LazyLoader("dashscope")

MODEL_ZOO = {}

# Default cached models links for downloading
MODEL_LINKS = "https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/data_juicer/models/"

# Backup cached models links for downloading
BACKUP_MODEL_LINKS = {
    # language identification model from fasttext
    "lid.176.bin": "https://dl.fbaipublicfiles.com/fasttext/supervised-models/",
    # tokenizer and language model for English from sentencepiece and KenLM
    "*.sp.model": "https://huggingface.co/edugp/kenlm/resolve/main/wikipedia/",
    "*.arpa.bin": "https://huggingface.co/edugp/kenlm/resolve/main/wikipedia/",
    # sentence split model from nltk punkt
    "punkt.*.pickle": "https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/data_juicer/models/",
    # ram
    "ram_plus_swin_large_14m.pth": "http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/data_juicer/models/ram_plus_swin_large_14m.pth",
    # FastSAM
    "FastSAM-s.pt": "https://github.com/ultralytics/assets/releases/download/v8.2.0/FastSAM-s.pt",
    "FastSAM-x.pt": "https://github.com/ultralytics/assets/releases/download/v8.2.0/FastSAM-x.pt",
    # spacy
    "*_core_web_md-3.*.0": "https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/data_juicer/models/",
}



def check_model(model_name, force=False):
    """
    Check whether a model exists in DATA_JUICER_MODELS_CACHE.
    If it exists, return its full path.
    Else, download it from the cached model links.

    :param model_name: a specified model name
    :param force: whether to download the model forcefully. Sometimes the
        model file may be incomplete for some reason, so it needs to be
        downloaded again forcefully.
    """
    # check for local model
    if not force and os.path.exists(model_name):
        return model_name

    if not os.path.exists(DJMC):
        os.makedirs(DJMC)

    # check if the specified model exists. If it does not exist, download it
    cached_model_path = os.path.join(DJMC, model_name)
    if not force and os.path.exists(cached_model_path):
        # cached-path case described in the docstring: return it directly
        return cached_model_path
    if force:
        if os.path.exists(cached_model_path):
            os.remove(cached_model_path)
            logger.info(f"Model [{cached_model_path}] is invalid. Forcing download...")
        else:
            logger.info(f"Model [{cached_model_path}] is not found. Downloading...")

    model_link = os.path.join(MODEL_LINKS, model_name)
    try:
        wget.download(model_link, cached_model_path)
    except:  # noqa: E722
        backup_model_link = get_backup_model_link(model_name)
        if backup_model_link is not None:
            backup_model_link = os.path.join(backup_model_link, model_name)
        try:
            wget.download(backup_model_link, cached_model_path)
        except:  # noqa: E722
            import traceback

            traceback.print_exc()
            raise RuntimeError(
                f"Downloading model [{model_name}] error. "
                f"Please retry later or download it into {DJMC} "
                f"manually from {model_link} or {backup_model_link}"
            )
    return cached_model_path
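# Usage sketch (illustrative, not part of the module): resolve a model file
# through the cache, forcing a fresh download when a cached copy is corrupt.
#
#   path = check_model("lid.176.bin")              # download on first use
#   path = check_model("lid.176.bin", force=True)  # re-download a bad file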
def filter_arguments(func, args_dict):
    """
    Filters and returns only the valid arguments for a given function signature.

    :param func: The function or callable to inspect.
    :param args_dict: A dictionary of argument names and values to filter.
    :return: A dictionary containing only the arguments that match the
        function's signature, preserving any **kwargs if applicable.
    """
    params = inspect.signature(func).parameters
    filtered_args = {}
    for name, param in params.items():
        # if the callable accepts **kwargs, every argument is valid
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            return args_dict
        if name not in {"self", "cls"} and name in args_dict:
            filtered_args[name] = args_dict[name]
    return filtered_args
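# Usage sketch (illustrative): keys the target callable does not accept (here
# the out-of-place `max_tokens`) are silently dropped, which is how the API
# model classes below forward arbitrary kwargs to `openai.OpenAI` safely.
#
#   kwargs = {"api_key": "sk-xxx", "base_url": "http://localhost:8000/v1",
#             "max_tokens": 512}
#   filter_arguments(openai.OpenAI, kwargs)
#   # -> {"api_key": "sk-xxx", "base_url": "http://localhost:8000/v1"}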
class ChatAPIModel:
    def __init__(self, model, endpoint=None, response_path=None, **kwargs):
        """
        Initializes an instance of the ChatAPIModel class.

        :param model: The name of the model to be used for making API calls.
            This should correspond to a valid model identifier recognized by
            the API server.
        :param endpoint: The URL endpoint for the API. If provided as a
            relative path, it will be appended to the base URL (defined by
            the `OPENAI_BASE_URL` environment variable or through an
            additional `base_url` parameter). Defaults to '/chat/completions'
            for OpenAI compatibility.
        :param response_path: A dot-separated string specifying the path to
            extract the desired content from the API response. The default
            value is 'choices.0.message.content', which corresponds to the
            typical structure of an OpenAI API response.
        :param kwargs: Additional keyword arguments for configuring the
            internal OpenAI client.
        """
        self.model = model
        self.endpoint = endpoint or "/chat/completions"
        self.response_path = response_path or "choices.0.message.content"

        client_args = filter_arguments(openai.OpenAI, kwargs)
        self._client = openai.OpenAI(**client_args)

    def __call__(self, messages, **kwargs):
        """
        Sends messages to the configured API model and returns the parsed
        response content.

        :param messages: A list of message dictionaries to send to the API.
            Each message should have a 'role' (e.g., 'user', 'assistant')
            and 'content' (the message text).
        :param kwargs: Additional parameters for the API call.
        :return: The parsed response content from the API call, or an empty
            string if an error occurs.
        """
        body = {
            "messages": messages,
            "model": self.model,
        }
        body.update(kwargs)

        stream = kwargs.get("stream", False)
        stream_cls = openai.Stream[openai.types.chat.ChatCompletionChunk]

        try:
            response = self._client.post(
                self.endpoint,
                body=body,
                cast_to=httpx.Response,
                stream=stream,
                stream_cls=stream_cls,
            )
            result = response.json()
            return nested_access(result, self.response_path)
        except Exception as e:
            logger.exception(e)
            return ""
class EmbeddingAPIModel:
    def __init__(self, model, endpoint=None, response_path=None, **kwargs):
        """
        Initializes an instance specialized for embedding APIs.

        :param model: The model identifier for embedding API calls.
        :param endpoint: API endpoint URL. Defaults to '/embeddings'.
        :param response_path: Path to extract embeddings from the response.
            Defaults to 'data.0.embedding'.
        :param kwargs: Configuration for the OpenAI client.
        """
        self.model = model
        self.endpoint = endpoint or "/embeddings"
        self.response_path = response_path or "data.0.embedding"

        client_args = filter_arguments(openai.OpenAI, kwargs)
        self._client = openai.OpenAI(**client_args)

    def __call__(self, input, **kwargs):
        """
        Processes input text and returns embeddings.

        :param input: Input text or list of texts to embed.
        :param kwargs: Additional API parameters.
        :return: Extracted embeddings, or an empty list on error.
        """
        body = {
            "model": self.model,
            "input": input,
        }
        body.update(kwargs)

        try:
            response = self._client.post(self.endpoint, body=body, cast_to=httpx.Response)
            result = response.json()
            return nested_access(result, self.response_path) or []
        except Exception as e:
            logger.exception(f"Embedding API error: {e}")
            return []
def prepare_api_model(
    model, *, endpoint=None, response_path=None, return_processor=False, processor_config=None, **model_params
):
    """Creates a callable API model for interacting with an OpenAI-compatible
    API. The callable supports custom response parsing and works with proxy
    servers that may be incompatible.

    :param model: The name of the model to interact with.
    :param endpoint: The URL endpoint for the API. If provided as a relative
        path, it will be appended to the base URL (defined by the
        `OPENAI_BASE_URL` environment variable or through an additional
        `base_url` parameter). Supported endpoints include:
        - '/chat/completions' for chat models
        - '/embeddings' for embedding models
        Defaults to '/chat/completions' for OpenAI compatibility.
    :param response_path: The dot-separated path to extract desired content
        from the API response. Defaults to 'choices.0.message.content' for
        chat models and 'data.0.embedding' for embedding models.
    :param return_processor: A boolean flag indicating whether to return a
        processor along with the model. The processor can be used for tasks
        like tokenization or encoding. Defaults to False.
    :param processor_config: A dictionary containing configuration parameters
        for initializing a Hugging Face processor. It is only relevant if
        `return_processor` is set to True.
    :param model_params: Additional parameters for configuring the API model.
    :return: A callable APIModel instance, and optionally a processor if
        `return_processor` is True.
    """
    endpoint = endpoint or "/chat/completions"

    ENDPOINT_CLASS_MAP = {
        "chat": ChatAPIModel,
        "embeddings": EmbeddingAPIModel,
    }

    API_Class = next((cls for keyword, cls in ENDPOINT_CLASS_MAP.items() if keyword in endpoint.lower()), None)
    if API_Class is None:
        raise ValueError(f"Unsupported endpoint: {endpoint}")

    client = API_Class(model=model, endpoint=endpoint, response_path=response_path, **model_params)

    if not return_processor:
        return client

    def get_processor():
        try:
            return tiktoken.encoding_for_model(model)
        except Exception:
            pass

        try:
            return dashscope.get_tokenizer(model)
        except Exception:
            pass

        try:
            processor = transformers.AutoProcessor.from_pretrained(
                pretrained_model_name_or_path=model, **processor_config
            )
            return processor
        except Exception:
            pass

        raise ValueError(
            "Failed to initialize the processor. Please check the following:\n"
            "- For OpenAI models: Install 'tiktoken' via `pip install tiktoken`.\n"
            "- For DashScope models: Install both 'dashscope' and 'tiktoken' via `pip install dashscope tiktoken`.\n"
            "- For custom models: Use the 'processor_config' parameter to configure a Hugging Face processor."
        )

    if processor_config is not None and "pretrained_model_name_or_path" in processor_config:
        processor = transformers.AutoProcessor.from_pretrained(**processor_config)
    else:
        processor = get_processor()

    return (client, processor)
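# Usage sketch (illustrative): the endpoint string selects the wrapper class.
# Assumes OPENAI_API_KEY / OPENAI_BASE_URL are set in the environment or
# passed through `model_params`.
#
#   chat = prepare_api_model("gpt-4o-mini")
#   reply = chat([{"role": "user", "content": "Hello!"}])
#
#   embed = prepare_api_model("text-embedding-3-small", endpoint="/embeddings")
#   vector = embed("text to embed")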
def prepare_diffusion_model(pretrained_model_name_or_path, diffusion_type, **model_params):
    """
    Prepare and load a Diffusion model from HuggingFace.

    :param pretrained_model_name_or_path: input Diffusion model name or local
        path to the model
    :param diffusion_type: the use of the diffusion model. It can be
        'image2image', 'text2image', 'inpainting'
    :return: a Diffusion model.
    """
    TORCH_DTYPE_MAPPING = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    LazyLoader.check_packages(["torch", "transformers"])

    device = model_params.pop("device", None)
    if not device:
        model_params["device_map"] = "balanced"
    if "torch_dtype" in model_params:
        model_params["torch_dtype"] = TORCH_DTYPE_MAPPING[model_params["torch_dtype"]]

    diffusion_type_to_pipeline = {
        "image2image": diffusers.AutoPipelineForImage2Image,
        "text2image": diffusers.AutoPipelineForText2Image,
        "inpainting": diffusers.AutoPipelineForInpainting,
    }

    if diffusion_type not in diffusion_type_to_pipeline:
        raise ValueError(
            f"Unsupported diffusion_type {diffusion_type} for diffusion "
            "model. Can only be one of "
            '["image2image", "text2image", "inpainting"].'
        )

    pipeline = diffusion_type_to_pipeline[diffusion_type]
    model = pipeline.from_pretrained(pretrained_model_name_or_path, **model_params)
    if device:
        model = model.to(device)

    return model
def prepare_fastsam_model(model_path, **model_params):
    device = model_params.pop("device", "cpu")
    model = ultralytics.FastSAM(check_model(model_path)).to(device)
    return model
def prepare_fasttext_model(model_name="lid.176.bin", **model_params):
    """
    Prepare and load a fasttext model.

    :param model_name: input model name
    :return: model instance.
    """
    logger.info("Loading fasttext language identification model...")
    try:
        # Suppress FastText warnings by redirecting stderr
        with redirect_stderr(io.StringIO()):
            ft_model = fasttext.load_model(check_model(model_name))
        # Verify the model has the predict method (for language identification)
        if not hasattr(ft_model, "predict"):
            raise AttributeError("Loaded model does not support prediction")
    except Exception as e:
        logger.warning(f"Error loading model: {e}. Attempting to force download...")
        try:
            with redirect_stderr(io.StringIO()):
                ft_model = fasttext.load_model(check_model(model_name, force=True))
            if not hasattr(ft_model, "predict"):
                raise AttributeError("Loaded model does not support prediction")
        except Exception as e:
            logger.error(f"Failed to load model after download attempt: {e}")
            raise
    return ft_model
def prepare_huggingface_model(
    pretrained_model_name_or_path, *, return_model=True, return_pipe=False, pipe_task="text-generation", **model_params
):
    """
    Prepare and load a huggingface model.

    :param pretrained_model_name_or_path: model name or path
    :param return_model: return model or not
    :param return_pipe: return pipeline or not
    :param pipe_task: task for pipeline
    :return: a tuple (model, processor) if `return_model` is True; otherwise,
        only the processor is returned.
    """
    # Check if we need accelerate for device_map
    if "device" in model_params:
        device = model_params.pop("device")
        if device.startswith("cuda"):
            try:
                # device_map requires the accelerate package
                import accelerate  # noqa: F401

                model_params["device_map"] = device
            except ImportError:
                # If accelerate is not available, use device directly
                model_params["device"] = device
                logger.warning("accelerate not found, using device directly")

    processor = transformers.AutoProcessor.from_pretrained(pretrained_model_name_or_path, **model_params)

    if return_model:
        config = transformers.AutoConfig.from_pretrained(pretrained_model_name_or_path, **model_params)
        if hasattr(config, "auto_map"):
            class_name = next((k for k in config.auto_map if k.startswith("AutoModel")), "AutoModel")
        else:
            # TODO: What happens if more than one architecture is listed?
            class_name = config.architectures[0]

        model_class = getattr(transformers, class_name)
        model = model_class.from_pretrained(pretrained_model_name_or_path, **model_params)

        if return_pipe:
            if isinstance(processor, transformers.PreTrainedTokenizerBase):
                pipe_params = {"tokenizer": processor}
            elif isinstance(processor, transformers.SequenceFeatureExtractor):
                pipe_params = {"feature_extractor": processor}
            elif isinstance(processor, transformers.BaseImageProcessor):
                pipe_params = {"image_processor": processor}
            else:
                pipe_params = {}
            pipe = transformers.pipeline(task=pipe_task, model=model, config=config, **pipe_params)
            model = pipe

    return (model, processor) if return_model else processor
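# Usage sketch (illustrative): load a model together with its processor, or
# wrap both in a transformers pipeline.
#
#   model, tokenizer = prepare_huggingface_model("gpt2", device="cuda:0")
#   pipe, tokenizer = prepare_huggingface_model(
#       "gpt2", return_pipe=True, pipe_task="text-generation")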
def prepare_kenlm_model(lang, name_pattern="{}.arpa.bin", **model_params):
    """
    Prepare and load a kenlm model.

    :param lang: language to render the model name
    :param name_pattern: model name pattern in formatting syntax
    :return: model instance.
    """
    model_params.pop("device", None)

    model_name = name_pattern.format(lang)

    logger.info("Loading kenlm language model...")
    try:
        kenlm_model = kenlm.Model(check_model(model_name), **model_params)
    except:  # noqa: E722
        kenlm_model = kenlm.Model(check_model(model_name, force=True), **model_params)
    return kenlm_model
def prepare_nltk_model(lang, name_pattern="punkt.{}.pickle", **model_params):
    """
    Prepare and load a nltk punkt model with enhanced resource handling.

    :param lang: language to render the model name
    :param name_pattern: model name pattern in formatting syntax
    :return: model instance.
    """
    model_params.pop("device", None)

    # Ensure pickle security is patched
    patch_nltk_pickle_security()

    nltk_to_punkt = {"en": "english", "fr": "french", "pt": "portuguese", "es": "spanish"}
    assert lang in nltk_to_punkt.keys(), "lang must be one of the following: {}".format(list(nltk_to_punkt.keys()))

    logger.info("Loading nltk punkt split model...")
    try:
        # Resource path and fallback for the punkt model
        resource_path = f"tokenizers/punkt/{nltk_to_punkt[lang]}.pickle"

        # Ensure the resource is available
        if ensure_nltk_resource(resource_path, "punkt"):
            logger.info(f"Successfully verified resource {resource_path}")
        else:
            logger.warning(f"Could not verify resource {resource_path}, model may not work correctly")

        # Load the model
        nltk_model = nltk.data.load(resource_path, **model_params)
    except Exception as e:
        # Fallback to downloading and retrying
        logger.warning(f"Error loading model: {e}. Attempting to download...")
        try:
            nltk.download("punkt", quiet=False)
            nltk_model = nltk.data.load(resource_path, **model_params)
        except Exception as download_error:
            logger.error(f"Failed to load model after download attempt: {download_error}")
            raise

    return nltk_model
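# Usage sketch (illustrative): the returned punkt model exposes `tokenize`
# for sentence splitting.
#
#   splitter = prepare_nltk_model("en")
#   sentences = splitter.tokenize("First sentence. Second one.")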
def prepare_nltk_pos_tagger(**model_params):
    """
    Prepare and load NLTK's part-of-speech tagger with enhanced resource
    handling.

    :return: The POS tagger model
    """
    model_params.pop("device", None)

    # Ensure pickle security is patched
    patch_nltk_pickle_security()

    logger.info("Loading NLTK POS tagger model...")
    try:
        # Resource path and fallback for the averaged_perceptron_tagger
        resource_path = "taggers/averaged_perceptron_tagger/english.pickle"

        # Ensure the resource is available
        if ensure_nltk_resource(resource_path, "averaged_perceptron_tagger"):
            logger.info(f"Successfully verified resource {resource_path}")
        else:
            logger.warning(f"Could not verify resource {resource_path}, model may not work correctly")

        # Import the POS tagger
        import nltk.tag

        tagger = nltk.tag.pos_tag
    except Exception as e:
        # Fallback to downloading and retrying
        logger.warning(f"Error loading POS tagger: {e}. Attempting to download...")
        try:
            nltk.download("averaged_perceptron_tagger", quiet=False)
            import nltk.tag

            tagger = nltk.tag.pos_tag
        except Exception as download_error:
            logger.error(f"Failed to load POS tagger after download attempt: {download_error}")
            raise

    return tagger
def prepare_opencv_classifier(model_path, **model_params):
    model = cv2.CascadeClassifier(model_path)
    return model
def prepare_recognizeAnything_model(
    pretrained_model_name_or_path="ram_plus_swin_large_14m.pth", input_size=384, **model_params
):
    """
    Prepare and load the recognizeAnything model.

    :param pretrained_model_name_or_path: input model name or path.
    :param input_size: the input size of the model.
    """
    logger.info("Loading recognizeAnything model...")

    try:
        model = ram.models.ram_plus(
            pretrained=check_model(pretrained_model_name_or_path), image_size=input_size, vit="swin_l"
        )
    except (RuntimeError, UnpicklingError) as e:
        logger.warning(e)
        model = ram.models.ram_plus(
            pretrained=check_model(pretrained_model_name_or_path, force=True), image_size=input_size, vit="swin_l"
        )
    device = model_params.pop("device", "cpu")
    model.to(device).eval()
    return model
def prepare_sdxl_prompt2prompt(pretrained_model_name_or_path, pipe_func, torch_dtype="fp32", device="cpu"):
    dtype = torch.float32 if torch_dtype == "fp32" else torch.float16
    model = pipe_func.from_pretrained(pretrained_model_name_or_path, torch_dtype=dtype, use_safetensors=True).to(
        device
    )
    return model
def prepare_sentencepiece_model(model_path, **model_params):
    """
    Prepare and load a sentencepiece model.

    :param model_path: input model path
    :return: model instance
    """
    logger.info("Loading sentencepiece model...")
    sentencepiece_model = sentencepiece.SentencePieceProcessor()
    try:
        sentencepiece_model.load(check_model(model_path))
    except:  # noqa: E722
        sentencepiece_model.load(check_model(model_path, force=True))
    return sentencepiece_model
def prepare_sentencepiece_for_lang(lang, name_pattern="{}.sp.model", **model_params):
    """
    Prepare and load a sentencepiece model for a specific language.

    :param lang: language to render the model name
    :param name_pattern: pattern to render the model name
    :return: model instance.
    """
    model_name = name_pattern.format(lang)
    return prepare_sentencepiece_model(model_name)
def prepare_simple_aesthetics_model(pretrained_model_name_or_path, *, return_model=True, **model_params):
    """
    Prepare and load a simple aesthetics model.

    :param pretrained_model_name_or_path: model name or path
    :param return_model: return model or not
    :return: a tuple (model, input processor) if `return_model` is True;
        otherwise, only the processor is returned.
    """
    # Check if we need accelerate for device_map
    if "device" in model_params:
        device = model_params.pop("device")
        if device.startswith("cuda"):
            try:
                # device_map requires the accelerate package
                import accelerate  # noqa: F401

                model_params["device_map"] = device
            except ImportError:
                # If accelerate is not available, use device directly
                model_params["device"] = device
                logger.warning("accelerate not found, using device directly")

    processor = transformers.CLIPProcessor.from_pretrained(pretrained_model_name_or_path, **model_params)

    if not return_model:
        return processor
    else:
        if "v1" in pretrained_model_name_or_path:
            model = aes_pred.AestheticsPredictorV1.from_pretrained(pretrained_model_name_or_path, **model_params)
        elif "v2" in pretrained_model_name_or_path and "linear" in pretrained_model_name_or_path:
            model = aes_pred.AestheticsPredictorV2Linear.from_pretrained(pretrained_model_name_or_path, **model_params)
        elif "v2" in pretrained_model_name_or_path and "relu" in pretrained_model_name_or_path:
            model = aes_pred.AestheticsPredictorV2ReLU.from_pretrained(pretrained_model_name_or_path, **model_params)
        else:
            raise ValueError("Unsupported model: {}".format(pretrained_model_name_or_path))
        return (model, processor)
def prepare_spacy_model(lang, name_pattern="{}_core_web_md-3.7.0", **model_params):
    """
    Prepare a spacy model for a specific language.

    :param lang: language of the spacy model. Should be one of ["zh", "en"]
    :return: corresponding spacy model
    """
    import spacy

    assert lang in ["zh", "en"], "Diversity only supports zh and en"
    model_name = name_pattern.format(lang)
    logger.info(f"Loading spacy model [{model_name}]...")
    compressed_model = "{}.tar.gz".format(model_name)

    # decompress the compressed model if it's not decompressed
    def decompress_model(compressed_model_path):
        if not compressed_model_path.endswith(".tar.gz"):
            raise ValueError("Only .tar.gz files are supported")

        decompressed_model_path = compressed_model_path.replace(".tar.gz", "")
        if os.path.exists(decompressed_model_path) and os.path.isdir(decompressed_model_path):
            return decompressed_model_path

        ver_name = os.path.basename(decompressed_model_path)
        unver_name = ver_name.rsplit("-", maxsplit=1)[0]
        target_dir_in_archive = f"{ver_name}/{unver_name}/{ver_name}/"

        import tarfile

        with tarfile.open(compressed_model_path, "r:gz") as tar:
            for member in tar.getmembers():
                if member.name.startswith(target_dir_in_archive):
                    # relative path without unnecessary directory levels
                    relative_path = os.path.relpath(member.name, start=target_dir_in_archive)
                    target_path = os.path.join(decompressed_model_path, relative_path)

                    if member.isfile():
                        # ensure the directory exists
                        target_directory = os.path.dirname(target_path)
                        os.makedirs(target_directory, exist_ok=True)

                        # for files, extract to the specific location
                        with tar.extractfile(member) as source:
                            with open(target_path, "wb") as target:
                                target.write(source.read())
        return decompressed_model_path

    try:
        diversity_model = spacy.load(decompress_model(check_model(compressed_model)))
    except:  # noqa: E722
        diversity_model = spacy.load(decompress_model(check_model(compressed_model, force=True)))
    return diversity_model
def prepare_video_blip_model(pretrained_model_name_or_path, *, return_model=True, **model_params):
    """
    Prepare and load a video-blip model with the corresponding processor.

    :param pretrained_model_name_or_path: model name or path
    :param return_model: return model or not
    :param trust_remote_code: passed to transformers
    :return: a tuple (model, input processor) if `return_model` is True;
        otherwise, only the processor is returned.
    """
    if "device" in model_params:
        model_params["device_map"] = model_params.pop("device")

    class VideoBlipVisionModel(transformers.Blip2VisionModel):
        """A simple, augmented version of Blip2VisionModel to handle videos."""

        def forward(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            interpolate_pos_encoding: bool = False,
        ) -> Union[tuple, transformers.modeling_outputs.BaseModelOutputWithPooling]:
            """Flatten `pixel_values` along the batch and time dimension,
            pass it through the original vision model,
            then unflatten it back.

            :param pixel_values: a tensor of shape
                (batch, channel, time, height, width)

            :returns:
                last_hidden_state: a tensor of shape
                    (batch, time * seq_len, hidden_size)
                pooler_output: a tensor of shape (batch, time, hidden_size)
                hidden_states: a tuple of tensors of shape
                    (batch, time * seq_len, hidden_size),
                    one for the output of the embeddings +
                    one for each layer
                attentions: a tuple of tensors of shape
                    (batch, time, num_heads, seq_len, seq_len),
                    one for each layer
            """
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            # resolve return_dict as in the standard transformers forward
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            batch, _, time, _, _ = pixel_values.size()

            # flatten along the batch and time dimension to create a
            # tensor of shape (batch * time, channel, height, width)
            flat_pixel_values = pixel_values.permute(0, 2, 1, 3, 4).flatten(end_dim=1)

            vision_outputs: transformers.modeling_outputs.BaseModelOutputWithPooling = super().forward(
                pixel_values=flat_pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                interpolate_pos_encoding=interpolate_pos_encoding,
            )

            # now restore the original dimensions
            # vision_outputs.last_hidden_state is of shape
            # (batch * time, seq_len, hidden_size)
            seq_len = vision_outputs.last_hidden_state.size(1)
            last_hidden_state = vision_outputs.last_hidden_state.view(batch, time * seq_len, -1)
            # vision_outputs.pooler_output is of shape
            # (batch * time, hidden_size)
            pooler_output = vision_outputs.pooler_output.view(batch, time, -1)
            # hidden_states is a tuple of tensors of shape
            # (batch * time, seq_len, hidden_size)
            hidden_states = (
                tuple(hidden.view(batch, time * seq_len, -1) for hidden in vision_outputs.hidden_states)
                if vision_outputs.hidden_states is not None
                else None
            )
            # attentions is a tuple of tensors of shape
            # (batch * time, num_heads, seq_len, seq_len)
            attentions = (
                tuple(hidden.view(batch, time, -1, seq_len, seq_len) for hidden in vision_outputs.attentions)
                if vision_outputs.attentions is not None
                else None
            )
            if return_dict:
                return transformers.modeling_outputs.BaseModelOutputWithPooling(
                    last_hidden_state=last_hidden_state,
                    pooler_output=pooler_output,
                    hidden_states=hidden_states,
                    attentions=attentions,
                )
            return (last_hidden_state, pooler_output, hidden_states, attentions)

    class VideoBlipForConditionalGeneration(transformers.Blip2ForConditionalGeneration):
        def __init__(self, config: transformers.Blip2Config) -> None:
            # HACK: we call the grandparent super().__init__() to bypass
            # transformers.Blip2ForConditionalGeneration.__init__() so we
            # can replace self.vision_model
            super(transformers.Blip2ForConditionalGeneration, self).__init__(config)

            self.vision_model = VideoBlipVisionModel(config.vision_config)

            self.query_tokens = nn.Parameter(
                torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
            )
            self.qformer = transformers.Blip2QFormerModel(config.qformer_config)

            self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
            if config.use_decoder_only_language_model:
                language_model = transformers.AutoModelForCausalLM.from_config(config.text_config)
            else:
                language_model = transformers.AutoModelForSeq2SeqLM.from_config(config.text_config)
            self.language_model = language_model

            # Initialize weights and apply final processing
            self.post_init()

    processor = transformers.AutoProcessor.from_pretrained(pretrained_model_name_or_path, **model_params)
    if return_model:
        model_class = VideoBlipForConditionalGeneration
        model = model_class.from_pretrained(pretrained_model_name_or_path, **model_params)
    return (model, processor) if return_model else processor
def prepare_vllm_model(pretrained_model_name_or_path, **model_params):
    """
    Prepare and load a vLLM model with the corresponding tokenizer.

    :param pretrained_model_name_or_path: model name or path
    :param model_params: LLM initialization parameters.
    :return: a tuple of (model, tokenizer)
    """
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    if model_params.get("device", "").startswith("cuda:"):
        model_params["device"] = "cuda"

    model = vllm.LLM(model=pretrained_model_name_or_path, generation_config="auto", **model_params)
    tokenizer = model.get_tokenizer()

    return (model, tokenizer)
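# Usage sketch (illustrative): vLLM manages device placement itself, which is
# why per-rank "cuda:N" strings are collapsed to plain "cuda" above.
#
#   model, tokenizer = prepare_vllm_model(
#       "Qwen/Qwen2-0.5B-Instruct", tensor_parallel_size=1)
#   outputs = model.generate(["Hello"], vllm.SamplingParams(max_tokens=16))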
def prepare_embedding_model(model_path, **model_params):
    """
    Prepare and load an embedding model using transformers.

    :param model_path: Path to the embedding model.
    :param model_params: Optional model parameters.
    :return: Model with encode() returning an embedding list.
    """
    logger.info("Loading embedding model using transformers...")
    # default to CPU so `device` is always defined for encode() below
    device = model_params.pop("device", "cpu")

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = transformers.AutoModel.from_pretrained(model_path, trust_remote_code=True).to(device).eval()

    def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # with left padding, the last position always holds the final token
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def encode(text, prompt_name=None, max_len=4096):
        if prompt_name:
            text = f"{prompt_name}: {text}"

        input_dict = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=max_len).to(
            device
        )
        with torch.no_grad():
            outputs = model(**input_dict)
        embedding = last_token_pool(outputs.last_hidden_state, input_dict["attention_mask"])
        embedding = nn.functional.normalize(embedding, p=2, dim=1)
        return embedding[0].tolist()

    # staticmethod keeps `encode` from being bound as an instance method,
    # which would otherwise shift `text` to the instance argument
    return type("EmbeddingModel", (), {"encode": staticmethod(encode)})()
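# Usage sketch (illustrative, model name is a placeholder): `encode` returns
# one L2-normalized vector per call via last-token pooling.
#
#   emb = prepare_embedding_model("some/embedding-model", device="cpu")
#   vector = emb.encode("hello world")  # -> list[float]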
def update_sampling_params(sampling_params, pretrained_model_name_or_path, enable_vllm=False):
    if enable_vllm:
        update_keys = {"max_tokens"}
    else:
        update_keys = {"max_new_tokens"}
    generation_config_keys = {
        "max_tokens": ["max_tokens", "max_new_tokens"],
        "max_new_tokens": ["max_tokens", "max_new_tokens"],
    }
    generation_config_thresholds = {
        "max_tokens": (max, 512),
        "max_new_tokens": (max, 512),
    }
    # try to get the generation configs
    from transformers import GenerationConfig

    try:
        model_generation_config = GenerationConfig.from_pretrained(pretrained_model_name_or_path).to_dict()
    except:  # noqa: E722
        logger.warning(f"No generation config found for the model [{pretrained_model_name_or_path}]")
        model_generation_config = {}

    for key in update_keys:
        # if the param is already in the input sampling_params, keep it as-is
        if key in sampling_params:
            logger.debug(f"Found param {key} in the input `sampling_params`.")
            continue
        # if not, try to find it in the generation_config of the model
        found = False
        for config_key in generation_config_keys[key]:
            if config_key in model_generation_config and model_generation_config[config_key]:
                sampling_params[key] = model_generation_config[config_key]
                found = True
                break
        if found:
            logger.debug(f"Found param {key} in the generation config as {sampling_params[key]}.")
            continue
        # if not again, fall back to the default threshold
        _, th = generation_config_thresholds[key]
        sampling_params[key] = th
        logger.debug(f"Use the threshold {th} as the sampling param {key}.")
    return sampling_params
MODEL_FUNCTION_MAPPING = {
    "api": prepare_api_model,
    "diffusion": prepare_diffusion_model,
    "fasttext": prepare_fasttext_model,
    "fastsam": prepare_fastsam_model,
    "huggingface": prepare_huggingface_model,
    "kenlm": prepare_kenlm_model,
    "nltk": prepare_nltk_model,
    "nltk_pos_tagger": prepare_nltk_pos_tagger,
    "opencv_classifier": prepare_opencv_classifier,
    "recognizeAnything": prepare_recognizeAnything_model,
    "sdxl-prompt-to-prompt": prepare_sdxl_prompt2prompt,
    "sentencepiece": prepare_sentencepiece_for_lang,
    "simple_aesthetics": prepare_simple_aesthetics_model,
    "spacy": prepare_spacy_model,
    "video_blip": prepare_video_blip_model,
    "vllm": prepare_vllm_model,
    "embedding": prepare_embedding_model,
}

_MODELS_WITHOUT_FILE_LOCK = {"fasttext", "fastsam", "kenlm", "nltk", "recognizeAnything", "sentencepiece", "spacy"}
def prepare_model(model_type, **model_kwargs):
    assert model_type in MODEL_FUNCTION_MAPPING.keys(), "model_type must be one of the following: {}".format(
        list(MODEL_FUNCTION_MAPPING.keys())
    )
    model_func = MODEL_FUNCTION_MAPPING[model_type]
    model_key = partial(model_func, **model_kwargs)
    if model_type in _MODELS_WITHOUT_FILE_LOCK:
        # initialize once in the main process to safely download model files
        model_key()
    return model_key
def get_model(model_key=None, rank=None, use_cuda=False):
    if model_key is None:
        return None

    global MODEL_ZOO
    if model_key not in MODEL_ZOO:
        logger.debug(f"{model_key} not found in MODEL_ZOO ({mp.current_process().name})")
        if use_cuda:
            rank = rank if rank is not None else 0
            rank = rank % cuda_device_count()
            device = f"cuda:{rank}"
        else:
            device = "cpu"
        MODEL_ZOO[model_key] = model_key(device=device)
    return MODEL_ZOO[model_key]
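# Usage sketch (illustrative): the intended pattern is to build a model key
# once (e.g. in an operator's __init__) and fetch the per-process, per-device
# instance lazily inside compute functions.
#
#   model_key = prepare_model("huggingface",
#                             pretrained_model_name_or_path="gpt2")
#   model, processor = get_model(model_key, rank=rank, use_cuda=True)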
def free_models(clear_model_zoo=True):
    global MODEL_ZOO
    for model_key in MODEL_ZOO:
        try:
            model = MODEL_ZOO[model_key]
            model.to("cpu")
            if clear_model_zoo:
                del model
        except Exception:
            pass
    if clear_model_zoo:
        MODEL_ZOO.clear()
    torch.cuda.empty_cache()