import logging
import os
import random
import time
from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.ops.op_fusion import LOADED_IMAGES
from data_juicer.utils.cache_utils import DATA_JUICER_ASSETS_CACHE
from data_juicer.utils.file_utils import transfer_data_dir
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model
# Heavy third-party deps are loaded lazily so importing this module stays cheap
# and works even when diffusers/torch are not yet installed.
diffusers = LazyLoader('diffusers', 'diffusers')
torch = LazyLoader('torch', 'torch')
# Project-local Prompt-to-Prompt pipeline; auto_install is disabled because it
# is shipped with data-juicer itself, not a pip package.
p2p_pipeline = LazyLoader('p2p_pipeline',
'data_juicer.ops.common.prompt2prompt_pipeline',
auto_install=False)
logger = logging.getLogger(__name__)
# NOTE(review): calling basicConfig at import time mutates the root logger and
# is normally left to the application, not library code -- confirm intent.
logging.basicConfig(level=logging.INFO)
# Registry name under which this operator is exposed.
OP_NAME = 'sdxl_prompt2prompt_mapper'
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class SDXLPrompt2PromptMapper(Mapper):
    """Generate pairs of similar images by the SDXL model.

    Each input sample must carry two related captions (under
    ``text_key`` and ``text_key_second``); a Prompt-to-Prompt SDXL
    pipeline renders them into a pair of closely related images, whose
    absolute file paths are written back into the sample under
    ``image_path1`` and ``image_path2``.
    """

    _accelerator = 'cuda'

    def __init__(
            self,
            hf_diffusion: str = 'stabilityai/stable-diffusion-xl-base-1.0',
            trust_remote_code=False,
            torch_dtype: str = 'fp32',
            num_inference_steps: int = 50,
            guidance_scale: float = 7.5,
            text_key=None,
            text_key_second=None,
            output_dir=DATA_JUICER_ASSETS_CACHE,
            *args,
            **kwargs):
        """
        Initialization method.

        :param hf_diffusion: diffusion model name on huggingface to
            generate the image.
        :param trust_remote_code: whether to trust remote code when
            loading the huggingface model.
        :param torch_dtype: the floating point type used to load the
            diffusion model.
        :param num_inference_steps: The larger the value, the better the
            image generation quality; however, this also increases the
            time required for generation.
        :param guidance_scale: A higher guidance scale value encourages
            the model to generate images closely linked to the text
            prompt at the expense of lower image quality. Guidance scale
            is enabled when ``guidance_scale > 1``.
        :param text_key: the key name used to store the first caption
            in the caption pair.
        :param text_key_second: the key name used to store the second
            caption in the caption pair.
        :param output_dir: the storage location of the generated images.
        """
        kwargs.setdefault('mem_required', '38GB')
        super().__init__(*args, **kwargs)
        self._init_parameters = self.remove_extra_parameters(locals())
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.hf_diffusion = hf_diffusion
        self.torch_dtype = torch_dtype
        # Only the model key is stored here; the heavy pipeline object is
        # materialized per worker via get_model() in process_single().
        self.model_key = prepare_model(
            model_type='sdxl-prompt-to-prompt',
            pretrained_model_name_or_path=hf_diffusion,
            pipe_func=p2p_pipeline.Prompt2PromptPipeline,
            torch_dtype=torch_dtype)
        self.text_key_second = text_key_second
        self.output_dir = output_dir
        if text_key is not None:
            self.text_key = text_key

    def process_single(self, sample, rank=None, context=False):
        """Generate an image pair for one sample and record the paths.

        :param sample: the sample dict holding the caption pair.
        :param rank: process rank used to select the model/device.
        :param context: unused; kept for interface compatibility.
        :return: the sample, augmented with ``image_path1`` and
            ``image_path2`` pointing at the saved JPEG files.
        :raises ValueError: if ``text_key_second`` was not configured.
        """
        # Fail fast: the previous code only logged an error here and then
        # crashed later on ``sample[None]`` with an opaque KeyError.
        if self.text_key_second is None:
            raise ValueError(
                'This OP (sdxl_prompt2prompt_mapper) requires processing '
                'multiple fields, and you need to specify valid '
                '`text_key_second`')

        # exist_ok already makes creation idempotent; a separate
        # os.path.exists() guard was redundant and race-prone.
        os.makedirs(self.output_dir, exist_ok=True)

        model = get_model(model_key=self.model_key,
                          rank=rank,
                          use_cuda=self.use_cuda())

        # A fresh CPU generator per call so repeated invocations yield
        # different image pairs; the seed itself is not persisted.
        g_cpu = torch.Generator().manual_seed(random.randint(0, 9999))
        # NOTE(review): the 'confetti' entry looks like a leftover from an
        # upstream Prompt-to-Prompt example -- confirm it is intentional.
        cross_attention_kwargs = {
            'edit_type': 'refine',
            'n_self_replace': 0.4,
            'n_cross_replace': {
                'default_': 1.0,
                'confetti': 0.8
            },
        }

        with torch.no_grad():
            prompts = [sample[self.text_key], sample[self.text_key_second]]
            image = model(prompts,
                          cross_attention_kwargs=cross_attention_kwargs,
                          guidance_scale=self.guidance_scale,
                          num_inference_steps=self.num_inference_steps,
                          generator=g_cpu)

        # Timestamp plus a random suffix keeps concurrent workers from
        # clobbering each other's files (unlikely, not impossible).
        timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
        random_num = str(random.randint(0, 9999))
        new_output_dir = transfer_data_dir(self.output_dir, OP_NAME)
        for idx, img in enumerate(image['images']):
            img_id = str(idx + 1)
            image_name = ('image_pair_' + timestamp + '_' + random_num +
                          '_' + img_id + '.jpg')
            abs_image_path = os.path.join(new_output_dir, image_name)
            img.save(abs_image_path)
            sample['image_path' + img_id] = abs_image_path
        return sample