import fnmatch
import inspect
import io
import os
from contextlib import redirect_stderr
from functools import partial
from pickle import UnpicklingError
from typing import Optional, Union
import httpx
import multiprocess as mp
import wget
from loguru import logger
from data_juicer import cuda_device_count
from data_juicer.utils.common_utils import nested_access
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.nltk_utils import (ensure_nltk_resource,
patch_nltk_pickle_security)
from .cache_utils import DATA_JUICER_MODELS_CACHE as DJMC
torch = LazyLoader('torch')
transformers = LazyLoader('transformers')
nn = LazyLoader('torch.nn')
fasttext = LazyLoader('fasttext', 'fasttext-wheel')
sentencepiece = LazyLoader('sentencepiece')
kenlm = LazyLoader('kenlm')
nltk = LazyLoader('nltk')
aes_pred = LazyLoader('aesthetics_predictor', 'simple-aesthetics-predictor')
vllm = LazyLoader('vllm')
diffusers = LazyLoader('diffusers')
ram = LazyLoader('ram',
'git+https://github.com/xinyu1205/recognize-anything.git')
cv2 = LazyLoader('cv2', 'opencv-python')
openai = LazyLoader('openai')
ultralytics = LazyLoader('ultralytics')
tiktoken = LazyLoader('tiktoken')
dashscope = LazyLoader('dashscope')
MODEL_ZOO = {}
# Default cached models links for downloading
MODEL_LINKS = 'https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/' \
'data_juicer/models/'
# Backup cached models links for downloading
BACKUP_MODEL_LINKS = {
# language identification model from fasttext
'lid.176.bin':
'https://dl.fbaipublicfiles.com/fasttext/supervised-models/',
# tokenizer and language model for English from sentencepiece and KenLM
'*.sp.model':
'https://huggingface.co/edugp/kenlm/resolve/main/wikipedia/',
'*.arpa.bin':
'https://huggingface.co/edugp/kenlm/resolve/main/wikipedia/',
# sentence split model from nltk punkt
'punkt.*.pickle':
'https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/'
'data_juicer/models/',
# ram
'ram_plus_swin_large_14m.pth':
'http://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/data_juicer/models/'
'ram_plus_swin_large_14m.pth',
# FastSAM
'FastSAM-s.pt':
'https://github.com/ultralytics/assets/releases/download/v8.2.0/'
'FastSAM-s.pt',
'FastSAM-x.pt':
'https://github.com/ultralytics/assets/releases/download/v8.2.0/'
'FastSAM-x.pt',
}
def get_backup_model_link(model_name):
for pattern, url in BACKUP_MODEL_LINKS.items():
if fnmatch.fnmatch(model_name, pattern):
return url
return None
def check_model(model_name, force=False):
"""
Check whether a model exists in DATA_JUICER_MODELS_CACHE.
If exists, return its full path.
Else, download it from cached models links.
:param model_name: a specified model name
:param force: Whether to download model forcefully or not, Sometimes
the model file maybe incomplete for some reason, so need to
download again forcefully.
"""
# check for local model
if not force and os.path.exists(model_name):
return model_name
if not os.path.exists(DJMC):
os.makedirs(DJMC)
# check if the specified model exists. If it does not exist, download it
cached_model_path = os.path.join(DJMC, model_name)
if force:
if os.path.exists(cached_model_path):
os.remove(cached_model_path)
logger.info(
f'Model [{cached_model_path}] is invalid. Forcing download...')
else:
logger.info(
f'Model [{cached_model_path}] is not found. Downloading...')
try:
model_link = os.path.join(MODEL_LINKS, model_name)
wget.download(model_link, cached_model_path)
    except:  # noqa: E722
        # resolve the backup link first so the error message below can still
        # reference it even when no backup link is available
        backup_model_link = get_backup_model_link(model_name)
        try:
            backup_model_link = os.path.join(backup_model_link, model_name)
            wget.download(backup_model_link, cached_model_path)
except: # noqa: E722
logger.error(
f'Downloading model [{model_name}] error. '
f'Please retry later or download it into {DJMC} '
f'manually from {model_link} or {backup_model_link} ')
exit(1)
return cached_model_path
def filter_arguments(func, args_dict):
"""
Filters and returns only the valid arguments for a given function
signature.
:param func: The function or callable to inspect.
:param args_dict: A dictionary of argument names and values to filter.
:return: A dictionary containing only the arguments that match the
function's signature, preserving any **kwargs if applicable.
"""
params = inspect.signature(func).parameters
filtered_args = {}
for name, param in params.items():
if param.kind == inspect.Parameter.VAR_KEYWORD:
return args_dict
if name not in {'self', 'cls'} and name in args_dict:
filtered_args[name] = args_dict[name]
return filtered_args
class ChatAPIModel:
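    """A callable wrapper around an OpenAI-compatible chat-completion API."""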
def __init__(self, model, endpoint=None, response_path=None, **kwargs):
"""
        Initializes an instance of the ChatAPIModel class.
:param model: The name of the model to be used for making API
calls. This should correspond to a valid model identifier
recognized by the API server.
:param endpoint: The URL endpoint for the API. If provided as a
relative path, it will be appended to the base URL (defined by the
`OPENAI_BASE_URL` environment variable or through an additional
`base_url` parameter). Defaults to '/chat/completions' for
OpenAI compatibility.
:param response_path: A dot-separated string specifying the path to
extract the desired content from the API response. The default
value is 'choices.0.message.content', which corresponds to the
typical structure of an OpenAI API response.
:param kwargs: Additional keyword arguments for configuring the
internal OpenAI client.
"""
self.model = model
self.endpoint = endpoint or '/chat/completions'
self.response_path = response_path or 'choices.0.message.content'
client_args = filter_arguments(openai.OpenAI, kwargs)
self._client = openai.OpenAI(**client_args)
def __call__(self, messages, **kwargs):
"""
Sends messages to the configured API model and returns the parsed
response content.
:param messages: A list of message dictionaries to send to the API.
Each message should have a 'role' (e.g., 'user',
'assistant') and 'content' (the message text).
:param kwargs: Additional parameters for the API call.
:return: The parsed response content from the API call, or an empty
string if an error occurs.
"""
body = {
'messages': messages,
'model': self.model,
}
body.update(kwargs)
stream = kwargs.get('stream', False)
stream_cls = openai.Stream[openai.types.chat.ChatCompletionChunk]
try:
response = self._client.post(self.endpoint,
body=body,
cast_to=httpx.Response,
stream=stream,
stream_cls=stream_cls)
result = response.json()
return nested_access(result, self.response_path)
except Exception as e:
logger.exception(e)
return ''
class EmbeddingAPIModel:
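    """A callable wrapper around an OpenAI-compatible embedding API."""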
def __init__(self, model, endpoint=None, response_path=None, **kwargs):
"""
Initializes an instance specialized for embedding APIs.
:param model: The model identifier for embedding API calls.
:param endpoint: API endpoint URL. Defaults to '/embeddings'.
:param response_path: Path to extract embeddings from response.
Defaults to 'data.0.embedding'.
:param kwargs: Configuration for the OpenAI client.
"""
self.model = model
self.endpoint = endpoint or '/embeddings'
self.response_path = response_path or 'data.0.embedding'
client_args = filter_arguments(openai.OpenAI, kwargs)
self._client = openai.OpenAI(**client_args)
def __call__(self, input, **kwargs):
"""
Processes input text and returns embeddings.
:param input: Input text or list of texts to embed.
:param kwargs: Additional API parameters.
:return: Extracted embeddings or empty list on error.
"""
body = {
'model': self.model,
'input': input,
}
body.update(kwargs)
try:
response = self._client.post(self.endpoint,
body=body,
cast_to=httpx.Response)
result = response.json()
return nested_access(result, self.response_path) or []
except Exception as e:
logger.exception(f'Embedding API error: {e}')
return []
def prepare_api_model(model,
*,
endpoint=None,
response_path=None,
return_processor=False,
processor_config=None,
**model_params):
"""Creates a callable API model for interacting with OpenAI-compatible API.
The callable supports custom response parsing and works with proxy servers
that may be incompatible.
:param model: The name of the model to interact with.
:param endpoint: The URL endpoint for the API. If provided as a relative
path, it will be appended to the base URL (defined by the
`OPENAI_BASE_URL` environment variable or through an additional
`base_url` parameter). Supported endpoints include:
- '/chat/completions' for chat models
- '/embeddings' for embedding models
Defaults to `/chat/completions` for OpenAI compatibility.
:param response_path: The dot-separated path to extract desired content
from the API response. Defaults to 'choices.0.message.content'
for chat models and 'data.0.embedding' for embedding models.
:param return_processor: A boolean flag indicating whether to return a
processor along with the model. The processor can be used for tasks
like tokenization or encoding. Defaults to False.
:param processor_config: A dictionary containing configuration parameters
for initializing a Hugging Face processor. It is only relevant if
`return_processor` is set to True.
:param model_params: Additional parameters for configuring the API model.
:return: A callable APIModel instance, and optionally a processor
if `return_processor` is True.
"""
endpoint = endpoint or '/chat/completions'
ENDPOINT_CLASS_MAP = {
'chat': ChatAPIModel,
'embeddings': EmbeddingAPIModel,
}
API_Class = next((cls for keyword, cls in ENDPOINT_CLASS_MAP.items()
if keyword in endpoint.lower()), None)
if API_Class is None:
raise ValueError(f'Unsupported endpoint: {endpoint}')
client = API_Class(model=model,
endpoint=endpoint,
response_path=response_path,
**model_params)
if not return_processor:
return client
def get_processor():
try:
return tiktoken.encoding_for_model(model)
except Exception:
pass
try:
return dashscope.get_tokenizer(model)
except Exception:
pass
try:
processor = transformers.AutoProcessor.from_pretrained(
pretrained_model_name_or_path=model, **processor_config)
return processor
except Exception:
pass
raise ValueError(
'Failed to initialize the processor. Please check the following:\n' # noqa: E501
"- For OpenAI models: Install 'tiktoken' via `pip install tiktoken`.\n" # noqa: E501
"- For DashScope models: Install both 'dashscope' and 'tiktoken' via `pip install dashscope tiktoken`.\n" # noqa: E501
"- For custom models: Use the 'processor_config' parameter to configure a Hugging Face processor." # noqa: E501
)
if processor_config is not None and \
'pretrained_model_name_or_path' in processor_config:
processor = transformers.AutoProcessor.from_pretrained(
**processor_config)
else:
processor = get_processor()
return (client, processor)
def prepare_diffusion_model(pretrained_model_name_or_path, diffusion_type,
**model_params):
"""
Prepare and load an Diffusion model from HuggingFace.
:param pretrained_model_name_or_path: input Diffusion model name
or local path to the model
:param diffusion_type: the use of the diffusion model. It can be
'image2image', 'text2image', 'inpainting'
:return: a Diffusion model.
"""
TORCH_DTYPE_MAPPING = {
'fp32': torch.float32,
'fp16': torch.float16,
'bf16': torch.bfloat16,
}
LazyLoader.check_packages(['torch', 'transformers'])
device = model_params.pop('device', None)
if not device:
model_params['device_map'] = 'balanced'
if 'torch_dtype' in model_params:
model_params['torch_dtype'] = TORCH_DTYPE_MAPPING[
model_params['torch_dtype']]
diffusion_type_to_pipeline = {
'image2image': diffusers.AutoPipelineForImage2Image,
'text2image': diffusers.AutoPipelineForText2Image,
'inpainting': diffusers.AutoPipelineForInpainting
}
if diffusion_type not in diffusion_type_to_pipeline.keys():
raise ValueError(
f'Not support {diffusion_type} diffusion_type for diffusion '
'model. Can only be one of '
'["image2image", "text2image", "inpainting"].')
pipeline = diffusion_type_to_pipeline[diffusion_type]
model = pipeline.from_pretrained(pretrained_model_name_or_path,
**model_params)
if device:
model = model.to(device)
return model
def prepare_fastsam_model(model_path, **model_params):
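    """
    Prepare and load a FastSAM model via ultralytics.

    :param model_path: model name or path, e.g. 'FastSAM-x.pt'; downloaded
        through `check_model` if not available locally.
    :return: model instance on the requested device (default 'cpu').
    """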
device = model_params.pop('device', 'cpu')
model = ultralytics.FastSAM(check_model(model_path)).to(device)
return model
def prepare_fasttext_model(model_name='lid.176.bin', **model_params):
"""
Prepare and load a fasttext model.
:param model_name: input model name
:return: model instance.
"""
logger.info('Loading fasttext language identification model...')
try:
# Suppress FastText warnings by redirecting stderr
with redirect_stderr(io.StringIO()):
ft_model = fasttext.load_model(check_model(model_name))
# Verify the model has the predict method (for language identification)
if not hasattr(ft_model, 'predict'):
raise AttributeError('Loaded model does not support prediction')
except Exception as e:
logger.warning(
f'Error loading model: {e}. Attempting to force download...')
try:
with redirect_stderr(io.StringIO()):
ft_model = fasttext.load_model(
check_model(model_name, force=True))
if not hasattr(ft_model, 'predict'):
raise AttributeError(
'Loaded model does not support prediction')
except Exception as e:
logger.error(f'Failed to load model after download attempt: {e}')
raise
return ft_model
def prepare_huggingface_model(pretrained_model_name_or_path,
*,
return_model=True,
return_pipe=False,
pipe_task='text-generation',
**model_params):
"""
Prepare and load a huggingface model.
:param pretrained_model_name_or_path: model name or path
:param return_model: return model or not
:param return_pipe: return pipeline or not
:param pipe_task: task for pipeline
:return: a tuple (model, processor) if `return_model` is True;
otherwise, only the processor is returned.
"""
# Check if we need accelerate for device_map
if 'device' in model_params:
device = model_params.pop('device')
if device.startswith('cuda'):
            try:
                # device_map requires accelerate to be installed
                import accelerate  # noqa: F401
                model_params['device_map'] = device
except ImportError:
# If accelerate is not available, use device directly
model_params['device'] = device
logger.warning('accelerate not found, using device directly')
processor = transformers.AutoProcessor.from_pretrained(
pretrained_model_name_or_path, **model_params)
if return_model:
config = transformers.AutoConfig.from_pretrained(
pretrained_model_name_or_path, **model_params)
if hasattr(config, 'auto_map'):
class_name = next(
(k for k in config.auto_map if k.startswith('AutoModel')),
'AutoModel')
else:
# TODO: What happens if more than one
class_name = config.architectures[0]
model_class = getattr(transformers, class_name)
model = model_class.from_pretrained(pretrained_model_name_or_path,
**model_params)
if return_pipe:
if isinstance(processor, transformers.PreTrainedTokenizerBase):
pipe_params = {'tokenizer': processor}
elif isinstance(processor, transformers.SequenceFeatureExtractor):
pipe_params = {'feature_extractor': processor}
            elif isinstance(processor, transformers.BaseImageProcessor):
                pipe_params = {'image_processor': processor}
            else:
                pipe_params = {}
pipe = transformers.pipeline(task=pipe_task,
model=model,
config=config,
**pipe_params)
model = pipe
return (model, processor) if return_model else processor
def prepare_kenlm_model(lang, name_pattern='{}.arpa.bin', **model_params):
"""
Prepare and load a kenlm model.
:param model_name: input model name in formatting syntax.
:param lang: language to render model name
:return: model instance.
"""
model_params.pop('device', None)
model_name = name_pattern.format(lang)
logger.info('Loading kenlm language model...')
try:
kenlm_model = kenlm.Model(check_model(model_name), **model_params)
except: # noqa: E722
kenlm_model = kenlm.Model(check_model(model_name, force=True),
**model_params)
return kenlm_model
def prepare_nltk_model(lang, name_pattern='punkt.{}.pickle', **model_params):
"""
Prepare and load a nltk punkt model with enhanced resource handling.
:param model_name: input model name in formatting syntax
:param lang: language to render model name
:return: model instance.
"""
model_params.pop('device', None)
# Ensure pickle security is patched
patch_nltk_pickle_security()
nltk_to_punkt = {
'en': 'english',
'fr': 'french',
'pt': 'portuguese',
'es': 'spanish'
}
assert lang in nltk_to_punkt.keys(
), 'lang must be one of the following: {}'.format(
list(nltk_to_punkt.keys()))
logger.info('Loading nltk punkt split model...')
try:
# Resource path and fallback for the punkt model
resource_path = f'tokenizers/punkt/{nltk_to_punkt[lang]}.pickle'
# Ensure the resource is available
if ensure_nltk_resource(resource_path, 'punkt'):
logger.info(f'Successfully verified resource {resource_path}')
else:
logger.warning(
f'Could not verify resource {resource_path}, model may not '
f'work correctly')
# Load the model
nltk_model = nltk.data.load(resource_path, **model_params)
except Exception as e:
# Fallback to downloading and retrying
logger.warning(f'Error loading model: {e}. Attempting to download...')
try:
nltk.download('punkt', quiet=False)
nltk_model = nltk.data.load(resource_path, **model_params)
except Exception as download_error:
logger.error(f'Failed to load model after download '
f'attempt: {download_error}')
raise
return nltk_model
def prepare_nltk_pos_tagger(**model_params):
"""
Prepare and load NLTK's part-of-speech tagger with enhanced resource
handling.
:return: The POS tagger model
"""
model_params.pop('device', None)
# Ensure pickle security is patched
patch_nltk_pickle_security()
logger.info('Loading NLTK POS tagger model...')
try:
# Resource path and fallback for the averaged_perceptron_tagger
resource_path = 'taggers/averaged_perceptron_tagger/english.pickle'
# Ensure the resource is available
if ensure_nltk_resource(resource_path, 'averaged_perceptron_tagger'):
logger.info(f'Successfully verified resource {resource_path}')
else:
logger.warning(
f'Could not verify resource {resource_path}, model may not '
f'work correctly')
# Import the POS tagger
import nltk.tag
tagger = nltk.tag.pos_tag
except Exception as e:
# Fallback to downloading and retrying
logger.warning(
f'Error loading POS tagger: {e}. Attempting to download...')
try:
nltk.download('averaged_perceptron_tagger', quiet=False)
import nltk.tag
tagger = nltk.tag.pos_tag
except Exception as download_error:
logger.error(f'Failed to load POS tagger after download '
f'attempt: {download_error}')
raise
return tagger
def prepare_opencv_classifier(model_path, **model_params):
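    """
    Prepare and load an OpenCV cascade classifier.

    :param model_path: path to the cascade classifier XML file.
    :return: a cv2.CascadeClassifier instance.
    """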
model = cv2.CascadeClassifier(model_path)
return model
def prepare_recognizeAnything_model(
pretrained_model_name_or_path='ram_plus_swin_large_14m.pth',
input_size=384,
**model_params):
"""
Prepare and load recognizeAnything model.
:param model_name: input model name.
:param input_size: the input size of the model.
"""
logger.info('Loading recognizeAnything model...')
try:
model = ram.models.ram_plus(
pretrained=check_model(pretrained_model_name_or_path),
image_size=input_size,
vit='swin_l')
except (RuntimeError, UnpicklingError) as e: # noqa: E722
logger.warning(e)
model = ram.models.ram_plus(pretrained=check_model(
pretrained_model_name_or_path, force=True),
image_size=input_size,
vit='swin_l')
device = model_params.pop('device', 'cpu')
model.to(device).eval()
return model
def prepare_sdxl_prompt2prompt(pretrained_model_name_or_path,
pipe_func,
torch_dtype='fp32',
device='cpu'):
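    """
    Prepare and load an SDXL pipeline for prompt-to-prompt editing.

    :param pretrained_model_name_or_path: model name or local path.
    :param pipe_func: pipeline class whose `from_pretrained` is used to
        load the model.
    :param torch_dtype: 'fp32' loads the model in float32; any other value
        loads it in float16.
    :param device: device to move the loaded pipeline to.
    :return: the loaded pipeline.
    """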
if torch_dtype == 'fp32':
model = pipe_func.from_pretrained(pretrained_model_name_or_path,
torch_dtype=torch.float32,
use_safetensors=True).to(device)
else:
model = pipe_func.from_pretrained(pretrained_model_name_or_path,
torch_dtype=torch.float16,
use_safetensors=True).to(device)
return model
def prepare_sentencepiece_model(model_path, **model_params):
"""
Prepare and load a sentencepiece model.
:param model_path: input model path
:return: model instance
"""
logger.info('Loading sentencepiece model...')
sentencepiece_model = sentencepiece.SentencePieceProcessor()
try:
sentencepiece_model.load(check_model(model_path))
except: # noqa: E722
sentencepiece_model.load(check_model(model_path, force=True))
return sentencepiece_model
def prepare_sentencepiece_for_lang(lang,
name_pattern='{}.sp.model',
**model_params):
"""
Prepare and load a sentencepiece model for specific language.
:param lang: language to render model name
:param name_pattern: pattern to render the model name
:return: model instance.
"""
model_name = name_pattern.format(lang)
return prepare_sentencepiece_model(model_name)
def prepare_simple_aesthetics_model(pretrained_model_name_or_path,
*,
return_model=True,
**model_params):
"""
Prepare and load a simple aesthetics model.
:param pretrained_model_name_or_path: model name or path
:param return_model: return model or not
:return: a tuple (model, input processor) if `return_model` is True;
otherwise, only the processor is returned.
"""
# Check if we need accelerate for device_map
if 'device' in model_params:
device = model_params.pop('device')
if device.startswith('cuda'):
            try:
                # device_map requires accelerate to be installed
                import accelerate  # noqa: F401
                model_params['device_map'] = device
except ImportError:
# If accelerate is not available, use device directly
model_params['device'] = device
logger.warning('accelerate not found, using device directly')
processor = transformers.CLIPProcessor.from_pretrained(
pretrained_model_name_or_path, **model_params)
if not return_model:
return processor
else:
if 'v1' in pretrained_model_name_or_path:
model = aes_pred.AestheticsPredictorV1.from_pretrained(
pretrained_model_name_or_path, **model_params)
elif ('v2' in pretrained_model_name_or_path
and 'linear' in pretrained_model_name_or_path):
model = aes_pred.AestheticsPredictorV2Linear.from_pretrained(
pretrained_model_name_or_path, **model_params)
elif ('v2' in pretrained_model_name_or_path
and 'relu' in pretrained_model_name_or_path):
model = aes_pred.AestheticsPredictorV2ReLU.from_pretrained(
pretrained_model_name_or_path, **model_params)
else:
raise ValueError(
'Not support {}'.format(pretrained_model_name_or_path))
return (model, processor)
def prepare_spacy_model(lang,
name_pattern='{}_core_web_md-3.7.0',
**model_params):
"""
Prepare spacy model for specific language.
:param lang: language of sapcy model. Should be one of ["zh",
"en"]
:return: corresponding spacy model
"""
import spacy
assert lang in ['zh', 'en'], 'Diversity only support zh and en'
model_name = name_pattern.format(lang)
logger.info(f'Loading spacy model [{model_name}]...')
compressed_model = '{}.tar.gz'.format(model_name)
# decompress the compressed model if it's not decompressed
def decompress_model(compressed_model_path):
if not compressed_model_path.endswith('.tar.gz'):
raise ValueError('Only .tar.gz files are supported')
decompressed_model_path = compressed_model_path.replace('.tar.gz', '')
if os.path.exists(decompressed_model_path) \
and os.path.isdir(decompressed_model_path):
return decompressed_model_path
ver_name = os.path.basename(decompressed_model_path)
unver_name = ver_name.rsplit('-', maxsplit=1)[0]
target_dir_in_archive = f'{ver_name}/{unver_name}/{ver_name}/'
import tarfile
with tarfile.open(compressed_model_path, 'r:gz') as tar:
for member in tar.getmembers():
if member.name.startswith(target_dir_in_archive):
# relative path without unnecessary directory levels
relative_path = os.path.relpath(
member.name, start=target_dir_in_archive)
target_path = os.path.join(decompressed_model_path,
relative_path)
if member.isfile():
# ensure the directory exists
target_directory = os.path.dirname(target_path)
os.makedirs(target_directory, exist_ok=True)
# for files, extract to the specific location
with tar.extractfile(member) as source:
with open(target_path, 'wb') as target:
target.write(source.read())
return decompressed_model_path
try:
diversity_model = spacy.load(
decompress_model(check_model(compressed_model)))
except: # noqa: E722
diversity_model = spacy.load(
decompress_model(check_model(compressed_model, force=True)))
return diversity_model
def prepare_video_blip_model(pretrained_model_name_or_path,
*,
return_model=True,
**model_params):
"""
Prepare and load a video-clip model with the corresponding processor.
:param pretrained_model_name_or_path: model name or path
:param return_model: return model or not
:param trust_remote_code: passed to transformers
:return: a tuple (model, input processor) if `return_model` is True;
otherwise, only the processor is returned.
"""
if 'device' in model_params:
model_params['device_map'] = model_params.pop('device')
class VideoBlipVisionModel(transformers.Blip2VisionModel):
"""A simple, augmented version of Blip2VisionModel to handle
videos."""
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[tuple,
transformers.modeling_outputs.BaseModelOutputWithPooling]:
"""Flatten `pixel_values` along the batch and time dimension,
pass it through the original vision model,
then unflatten it back.
:param pixel_values: a tensor of shape
(batch, channel, time, height, width)
:returns:
last_hidden_state: a tensor of shape
(batch, time * seq_len, hidden_size)
pooler_output: a tensor of shape
(batch, time, hidden_size)
hidden_states:
a tuple of tensors of shape
(batch, time * seq_len, hidden_size),
one for the output of the embeddings +
one for each layer
attentions:
a tuple of tensors of shape
(batch, time, num_heads, seq_len, seq_len),
one for each layer
"""
if pixel_values is None:
raise ValueError('You have to specify pixel_values')
batch, _, time, _, _ = pixel_values.size()
# flatten along the batch and time dimension to create a
# tensor of shape
# (batch * time, channel, height, width)
flat_pixel_values = pixel_values.permute(0, 2, 1, 3,
4).flatten(end_dim=1)
vision_outputs: transformers.modeling_outputs.BaseModelOutputWithPooling = super( # noqa: E501
).forward(
pixel_values=flat_pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
interpolate_pos_encoding=interpolate_pos_encoding,
)
# now restore the original dimensions
# vision_outputs.last_hidden_state is of shape
# (batch * time, seq_len, hidden_size)
seq_len = vision_outputs.last_hidden_state.size(1)
last_hidden_state = vision_outputs.last_hidden_state.view(
batch, time * seq_len, -1)
# vision_outputs.pooler_output is of shape
# (batch * time, hidden_size)
pooler_output = vision_outputs.pooler_output.view(batch, time, -1)
# hidden_states is a tuple of tensors of shape
# (batch * time, seq_len, hidden_size)
hidden_states = (tuple(
hidden.view(batch, time * seq_len, -1)
for hidden in vision_outputs.hidden_states)
if vision_outputs.hidden_states is not None else
None)
# attentions is a tuple of tensors of shape
# (batch * time, num_heads, seq_len, seq_len)
attentions = (tuple(
hidden.view(batch, time, -1, seq_len, seq_len)
for hidden in vision_outputs.attentions)
if vision_outputs.attentions is not None else None)
if return_dict:
return transformers.modeling_outputs.BaseModelOutputWithPooling( # noqa: E501
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=hidden_states,
attentions=attentions,
)
return (last_hidden_state, pooler_output, hidden_states,
attentions)
class VideoBlipForConditionalGeneration(
transformers.Blip2ForConditionalGeneration):
def __init__(self, config: transformers.Blip2Config) -> None:
# HACK: we call the grandparent super().__init__() to bypass
# transformers.Blip2ForConditionalGeneration.__init__() so we can
# replace self.vision_model
super(transformers.Blip2ForConditionalGeneration,
self).__init__(config)
self.vision_model = VideoBlipVisionModel(config.vision_config)
self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens,
config.qformer_config.hidden_size))
self.qformer = transformers.Blip2QFormerModel(
config.qformer_config)
self.language_projection = nn.Linear(
config.qformer_config.hidden_size,
config.text_config.hidden_size)
if config.use_decoder_only_language_model:
language_model = transformers.AutoModelForCausalLM.from_config(
config.text_config)
else:
language_model = transformers.AutoModelForSeq2SeqLM.from_config( # noqa: E501
config.text_config)
self.language_model = language_model
# Initialize weights and apply final processing
self.post_init()
processor = transformers.AutoProcessor.from_pretrained(
pretrained_model_name_or_path, **model_params)
if return_model:
model_class = VideoBlipForConditionalGeneration
model = model_class.from_pretrained(pretrained_model_name_or_path,
**model_params)
return (model, processor) if return_model else processor
def prepare_vllm_model(pretrained_model_name_or_path, **model_params):
"""
Prepare and load a HuggingFace model with the corresponding processor.
:param pretrained_model_name_or_path: model name or path
:param model_params: LLM initialization parameters.
:return: a tuple of (model, tokenizer)
"""
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if model_params.get('device', '').startswith('cuda:'):
model_params['device'] = 'cuda'
model = vllm.LLM(model=pretrained_model_name_or_path,
generation_config='auto',
**model_params)
tokenizer = model.get_tokenizer()
return (model, tokenizer)
def prepare_embedding_model(model_path, **model_params):
"""
Prepare and load an embedding model using transformers.
:param model_path: Path to the embedding model.
:param model_params: Optional model parameters.
:return: Model with encode() returning embedding list.
"""
logger.info('Loading embedding model using transformers...')
if 'device' in model_params:
device = model_params.pop('device')
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True)
model = transformers.AutoModel.from_pretrained(
model_path, trust_remote_code=True).to(device).eval()
def last_token_pool(last_hidden_states: torch.Tensor,
attention_mask: torch.Tensor) -> torch.Tensor:
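        # Pool the hidden state of each sequence's last non-padding token;
        # handles both left- and right-padded batches.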
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
return last_hidden_states[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_states.shape[0]
return last_hidden_states[
torch.arange(batch_size, device=last_hidden_states.device),
sequence_lengths]
def encode(text, prompt_name=None, max_len=4096):
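        # Tokenize the input, run the model, pool the last token's hidden
        # state and L2-normalize it to obtain a single embedding vector.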
if prompt_name:
text = f'{prompt_name}: {text}'
input_dict = tokenizer(text,
padding=True,
truncation=True,
return_tensors='pt',
max_length=max_len).to(device)
with torch.no_grad():
outputs = model(**input_dict)
embedding = last_token_pool(outputs.last_hidden_state,
input_dict['attention_mask'])
embedding = nn.functional.normalize(embedding, p=2, dim=1)
return embedding[0].tolist()
return type('EmbeddingModel', (), {'encode': encode})()
def update_sampling_params(sampling_params,
pretrained_model_name_or_path,
enable_vllm=False):
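    """
    Fill in missing generation-length params in `sampling_params`.

    If `max_tokens` (vLLM) or `max_new_tokens` (transformers) is not given,
    it is taken from the model's generation config when available, otherwise
    it falls back to a default of 512.

    :param sampling_params: sampling parameters to update.
    :param pretrained_model_name_or_path: model whose generation config is
        consulted.
    :param enable_vllm: whether the params target vLLM or transformers.
    :return: the updated sampling parameters.
    """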
if enable_vllm:
update_keys = {'max_tokens'}
else:
update_keys = {'max_new_tokens'}
generation_config_keys = {
'max_tokens': ['max_tokens', 'max_new_tokens'],
'max_new_tokens': ['max_tokens', 'max_new_tokens'],
}
generation_config_thresholds = {
'max_tokens': (max, 512),
'max_new_tokens': (max, 512),
}
# try to get the generation configs
from transformers import GenerationConfig
try:
model_generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path).to_dict()
except: # noqa: E722
logger.warning(f'No generation config found for the model '
f'[{pretrained_model_name_or_path}]')
model_generation_config = {}
for key in update_keys:
        # if the param is already given in the input sampling_params,
        # keep it as-is
if key in sampling_params:
logger.debug(f'Found param {key} in the input `sampling_params`.')
continue
# if not, try to find it in the generation_config of the model
found = False
for config_key in generation_config_keys[key]:
if config_key in model_generation_config \
and model_generation_config[config_key]:
sampling_params[key] = model_generation_config[config_key]
found = True
break
if found:
logger.debug(f'Found param {key} in the generation config as '
f'{sampling_params[key]}.')
continue
# if not again, use the threshold directly
_, th = generation_config_thresholds[key]
sampling_params[key] = th
logger.debug(f'Use the threshold {th} as the sampling param {key}.')
return sampling_params
MODEL_FUNCTION_MAPPING = {
'api': prepare_api_model,
'diffusion': prepare_diffusion_model,
'fasttext': prepare_fasttext_model,
'fastsam': prepare_fastsam_model,
'huggingface': prepare_huggingface_model,
'kenlm': prepare_kenlm_model,
'nltk': prepare_nltk_model,
'nltk_pos_tagger': prepare_nltk_pos_tagger,
'opencv_classifier': prepare_opencv_classifier,
'recognizeAnything': prepare_recognizeAnything_model,
'sdxl-prompt-to-prompt': prepare_sdxl_prompt2prompt,
'sentencepiece': prepare_sentencepiece_for_lang,
'simple_aesthetics': prepare_simple_aesthetics_model,
'spacy': prepare_spacy_model,
'video_blip': prepare_video_blip_model,
'vllm': prepare_vllm_model,
'embedding': prepare_embedding_model,
}
_MODELS_WITHOUT_FILE_LOCK = {
'fasttext', 'fastsam', 'kenlm', 'nltk', 'recognizeAnything',
'sentencepiece', 'spacy'
}
def prepare_model(model_type, **model_kwargs):
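    """
    Prepare a model initializer of the specified type.

    :param model_type: model type, must be one of the keys in
        MODEL_FUNCTION_MAPPING.
    :param model_kwargs: keyword arguments forwarded to the corresponding
        `prepare_*` function.
    :return: a partial function that acts as the model key in MODEL_ZOO and
        initializes the model when called (e.g. by `get_model`).
    """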
assert (model_type in MODEL_FUNCTION_MAPPING.keys()
), 'model_type must be one of the following: {}'.format(
list(MODEL_FUNCTION_MAPPING.keys()))
model_func = MODEL_FUNCTION_MAPPING[model_type]
model_key = partial(model_func, **model_kwargs)
if model_type in _MODELS_WITHOUT_FILE_LOCK:
# initialize once in the main process to safely download model files
model_key()
return model_key
def get_model(model_key=None, rank=None, use_cuda=False):
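    """
    Fetch a model instance from MODEL_ZOO, initializing it on first use.

    :param model_key: the partial function returned by `prepare_model`.
    :param rank: process rank used to select the CUDA device when
        `use_cuda` is True.
    :param use_cuda: whether to load the model on a CUDA device.
    :return: the model instance, or None if `model_key` is None.
    """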
if model_key is None:
return None
global MODEL_ZOO
if model_key not in MODEL_ZOO:
logger.debug(
f'{model_key} not found in MODEL_ZOO ({mp.current_process().name})'
)
if use_cuda:
rank = rank if rank is not None else 0
rank = rank % cuda_device_count()
device = f'cuda:{rank}'
else:
device = 'cpu'
MODEL_ZOO[model_key] = model_key(device=device)
return MODEL_ZOO[model_key]
def free_models(clear_model_zoo=True):
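    """
    Offload all models in MODEL_ZOO to CPU and optionally clear the zoo.

    :param clear_model_zoo: whether to clear MODEL_ZOO after offloading.
    """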
global MODEL_ZOO
for model_key in MODEL_ZOO:
try:
MODEL_ZOO[model_key].to('cpu')
except Exception:
pass
if clear_model_zoo:
MODEL_ZOO.clear()