import copy
import os
from typing import Optional
from PIL import Image
from pydantic import Field, PositiveInt
from typing_extensions import Annotated
from data_juicer.utils.constant import Fields
from data_juicer.utils.file_utils import transfer_filename
from data_juicer.utils.mm_utils import (SpecialTokens, load_data_with_context,
load_image, remove_special_tokens)
from data_juicer.utils.model_utils import get_model, prepare_model
from ..base_op import OPERATORS, Mapper
from ..op_fusion import LOADED_IMAGES
# Registry name of this op; it is also passed to ``transfer_filename`` when
# naming the generated image files.
OP_NAME: str = 'image_diffusion_mapper'
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class ImageDiffusionMapper(Mapper):
    """
    Generate image by diffusion model
    """

    _accelerator = 'cuda'
    _batched_op = True

    def __init__(self,
                 hf_diffusion: str = 'CompVis/stable-diffusion-v1-4',
                 trust_remote_code: bool = False,
                 torch_dtype: str = 'fp32',
                 revision: str = 'main',
                 strength: Annotated[float, Field(ge=0, le=1)] = 0.8,
                 guidance_scale: float = 7.5,
                 aug_num: PositiveInt = 1,
                 keep_original_sample: bool = True,
                 caption_key: Optional[str] = None,
                 hf_img2seq: str = 'Salesforce/blip2-opt-2.7b',
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_diffusion: diffusion model name on huggingface to generate
            the image.
        :param trust_remote_code: forwarded to ``prepare_model`` when loading
            the diffusion model.
        :param torch_dtype: the floating point type used to load the diffusion
            model. Can be one of ['fp32', 'fp16', 'bf16']
        :param revision: The specific model version to use. It can be a
            branch name, a tag name, a commit id, or any identifier allowed
            by Git.
        :param strength: Indicates extent to transform the reference image.
            Must be between 0 and 1. image is used as a starting point and
            more noise is added the higher the strength. The number of
            denoising steps depends on the amount of noise initially added.
            When strength is 1, added noise is maximum and the denoising
            process runs for the full number of iterations specified in
            num_inference_steps. A value of 1 essentially ignores image.
        :param guidance_scale: A higher guidance scale value encourages the
            model to generate images closely linked to the text prompt at the
            expense of lower image quality. Guidance scale is enabled when
            guidance_scale > 1.
        :param aug_num: the number of generated samples produced from each
            input sample that contains images. Each generated sample gets its
            own set of diffusion images.
        :param keep_original_sample: whether to keep the original samples in
            the output. If False, only the generated samples are returned and
            samples without images are dropped.
        :param caption_key: the key of the sample field that stores captions
            for the images. The field may hold a single string (used for
            every image of the sample) or a list with one caption per image
            entry. If it's None, captions are generated with the
            ``hf_img2seq`` model and prefixed with 'A photo of a '.
        :param hf_img2seq: model name on huggingface to generate caption if
            caption_key is None.

        Note:
            This is a batched_OP, whose input and output types are both
            lists. For a batch of $b$ samples that all contain images, with
            aug_num denoted as $M$, the output has $(1+M)b$ samples when
            keep_original_sample is True and $Mb$ samples when it's False.
        """
        super().__init__(*args, **kwargs)
        # Must be captured before any other local is created so that only
        # the constructor arguments are recorded; they feed the generated
        # file names in ``_process_single_sample``.
        self._init_parameters = self.remove_extra_parameters(locals())
        self.strength = strength
        self.guidance_scale = guidance_scale
        self.aug_num = aug_num
        self.keep_original_sample = keep_original_sample
        self.caption_key = caption_key
        # Prefix for generated captions (also given to the captioning op).
        self.prompt = 'A photo of a '
        if not self.caption_key:
            # No captions in the data: generate them on the fly.
            from .image_captioning_mapper import ImageCaptioningMapper
            self.op_generate_caption = ImageCaptioningMapper(
                hf_img2seq=hf_img2seq,
                keep_original_sample=False,
                prompt=self.prompt)
        self.model_key = prepare_model(
            model_type='diffusion',
            pretrained_model_name_or_path=hf_diffusion,
            diffusion_type='image2image',
            torch_dtype=torch_dtype,
            revision=revision,
            trust_remote_code=trust_remote_code)

    def _real_guidance(self, caption: str, image: Image.Image, rank=None):
        """
        Generate one image from ``image`` guided by the text ``caption``.

        The input is resized to 512x512 for the diffusion model, and the
        result is resized back to the input image's size.

        NOTE: when the pipeline has a safety checker and it flags the output
        as NSFW, generation is repeated. This retry loop has no upper bound.
        """
        canvas = image.resize((512, 512), Image.BILINEAR)
        diffusion_model = get_model(model_key=self.model_key,
                                    rank=rank,
                                    use_cuda=self.use_cuda())
        kwargs = dict(image=canvas,
                      prompt=[caption],
                      strength=self.strength,
                      guidance_scale=self.guidance_scale)
        has_nsfw_concept = True
        while has_nsfw_concept:
            outputs = diffusion_model(**kwargs)
            has_nsfw_concept = (diffusion_model.safety_checker is not None
                                and outputs.nsfw_content_detected[0])
        return outputs.images[0].resize(image.size, Image.BILINEAR)

    def _load_captions(self, ori_sample, loaded_image_keys, rank=None):
        """
        Return one caption per entry of ``loaded_image_keys``.

        Captions are read from ``ori_sample[self.caption_key]`` when
        ``caption_key`` is set; otherwise they are generated by the
        captioning op and prefixed with ``self.prompt``.
        """
        # Count image *entries*, not unique loaded images: captions are
        # indexed by position in ``loaded_image_keys``, which may repeat.
        num_images = len(loaded_image_keys)
        if self.caption_key:
            captions = ori_sample[self.caption_key]
            if not isinstance(captions, list):
                # one caption shared by all images
                captions = [captions] * num_images
            else:
                assert len(captions) == num_images, \
                    'The num of captions must match the num of images.'
            return [remove_special_tokens(c) for c in captions]

        caption_samples = {
            self.text_key: [SpecialTokens.image] * num_images,
            self.image_key: [[k] for k in loaded_image_keys]
        }
        caption_samples = self.op_generate_caption.process(caption_samples,
                                                           rank=rank)
        return [
            self.prompt + remove_special_tokens(c)
            for c in caption_samples[self.text_key]
        ]

    def _process_single_sample(self, ori_sample, rank=None, context=False):
        """
        :param ori_sample: a single data sample before applying generation
        :param rank: device rank forwarded to the models
        :param context: whether loaded/generated images are cached in the
            sample's ``Fields.context``
        :return: ``aug_num`` generated samples, or an empty list when the
            sample has no image
        """
        # there is no image in this sample
        if self.image_key not in ori_sample or \
                not ori_sample[self.image_key]:
            return []

        # load images; ``images`` maps each unique image key to its image
        loaded_image_keys = ori_sample[self.image_key]
        ori_sample, images = load_data_with_context(ori_sample, context,
                                                    loaded_image_keys,
                                                    load_image)
        captions = self._load_captions(ori_sample, loaded_image_keys, rank)

        # the generated results
        generated_samples = [
            copy.deepcopy(ori_sample) for _ in range(self.aug_num)
        ]
        for aug_id, generated_sample in enumerate(generated_samples):
            diffusion_image_keys = []
            for index, value in enumerate(loaded_image_keys):
                extra_params = {'caption': captions[index]}
                if aug_id > 0:
                    # Give every augmentation round its own file name;
                    # otherwise later rounds hit the cache below and reuse
                    # round 0's image. Round 0 keeps its historical name.
                    extra_params['aug_id'] = aug_id
                related_parameters = self.add_parameters(
                    self._init_parameters, **extra_params)
                diffusion_image_key = transfer_filename(
                    value, OP_NAME, **related_parameters)
                diffusion_image_keys.append(diffusion_image_key)
                # Skip generation only when this exact target was already
                # produced for this sample (e.g. a repeated image entry).
                # TODO: identical targets in other samples are regenerated.
                if not os.path.exists(diffusion_image_key) \
                        or diffusion_image_key not in images:
                    diffusion_image = self._real_guidance(captions[index],
                                                          images[value],
                                                          rank=rank)
                    images[diffusion_image_key] = diffusion_image
                    diffusion_image.save(diffusion_image_key)
                    if context:
                        generated_sample[Fields.context][
                            diffusion_image_key] = diffusion_image
            generated_sample[self.image_key] = diffusion_image_keys
        return generated_samples

    def process_batched(self, samples, rank=None, context=False):
        """
        Note:
            This is a batched_OP, whose input and output types are both
            "dict of lists". For a batch of $b$ samples that all contain
            images, with aug_num denoted as $M$, the output has $(1+M)b$
            samples when keep_original_sample is True and $Mb$ samples
            otherwise. Samples without images are kept only when
            keep_original_sample is True.

        :param samples: a batch in "dict of lists" format
        :param rank: device rank forwarded to the models
        :param context: whether images are cached in ``Fields.context``
        :return: the processed batch in "dict of lists" format; an empty
            batch with the input keys if no sample remains
        """
        # reconstruct samples from "dict of lists" to "list of dicts"
        num_samples = len(samples[self.text_key])
        reconstructed_samples = [{key: samples[key][i]
                                  for key in samples}
                                 for i in range(num_samples)]

        # do generation for each sample within the batch
        samples_after_generation = []
        for ori_sample in reconstructed_samples:
            if self.keep_original_sample:
                samples_after_generation.append(ori_sample)
            samples_after_generation.extend(
                self._process_single_sample(ori_sample,
                                            rank=rank,
                                            context=context))

        # nothing survived (empty batch, or no images and originals dropped)
        if not samples_after_generation:
            return {key: [] for key in samples}

        # reconstruct samples from "list of dicts" to "dict of lists"
        keys = samples_after_generation[0].keys()
        return {
            key: [s[key] for s in samples_after_generation]
            for key in keys
        }