# Source code for data_juicer.ops.mapper.sdxl_prompt2prompt_mapper

import logging

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.ops.op_fusion import LOADED_IMAGES
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

# Heavy dependencies are bound through LazyLoader rather than imported
# directly (presumably so the actual import is deferred until first use —
# confirm against data_juicer.utils.lazy_loader).
diffusers = LazyLoader('diffusers', 'diffusers')
torch = LazyLoader('torch', 'torch')
# NOTE(review): this statement was truncated after the first argument in the
# reviewed copy (unclosed call -> SyntaxError). The module path below is
# restored from the upstream data_juicer package layout; confirm it matches
# the installed version.
p2p_pipeline = LazyLoader('p2p_pipeline',
                          'data_juicer.ops.common.prompt2prompt_pipeline')

logger = logging.getLogger(__name__)

# Name under which the operator class is registered.
OP_NAME = 'sdxl_prompt2prompt_mapper'

@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class SDXLPrompt2PromptMapper(Mapper):
    """
    Generate pairs of similar images by the SDXL model.

    Two captions are read from each sample (fields ``text_key_second`` and
    ``text_key_third``) and passed together to a prompt-to-prompt SDXL
    pipeline; the generated images are written to the sample's image field.
    """

    _accelerator = 'cuda'

    def __init__(self,
                 hf_diffusion: str = 'stabilityai/stable-diffusion-xl-base-1.0',
                 trust_remote_code=False,
                 torch_dtype: str = 'fp32',
                 num_inference_steps: float = 50,
                 guidance_scale: float = 7.5,
                 text_key_second=None,
                 text_key_third=None,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_diffusion: diffusion model name on huggingface to generate
            the image.
        :param trust_remote_code: accepted as part of the constructor
            interface; it is not forwarded to the model preparation call in
            this operator.
        :param torch_dtype: the floating point type used to load the diffusion
            model.
        :param num_inference_steps: The larger the value, the better the
            image generation quality; however, this also increases the time
            required for generation.
        :param guidance_scale: A higher guidance scale value encourages the
            model to generate images closely linked to the text prompt at the
            expense of lower image quality. Guidance scale is enabled when
            `guidance_scale > 1`.
        :param text_key_second: name of the sample field holding the first
            caption in the caption pair.
        :param text_key_third: name of the sample field holding the second
            caption in the caption pair.
        """
        # Declare a default memory requirement unless the caller set one.
        kwargs.setdefault('mem_required', '38GB')
        super().__init__(*args, **kwargs)
        # Must run before any extra local is introduced: it snapshots the
        # constructor arguments through locals().
        self._init_parameters = self.remove_extra_parameters(locals())

        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.hf_diffusion = hf_diffusion
        self.torch_dtype = torch_dtype
        self.model_key = prepare_model(
            model_type='sdxl-prompt-to-prompt',
            pretrained_model_name_or_path=hf_diffusion,
            pipe_func=p2p_pipeline.Prompt2PromptPipeline,
            torch_dtype=torch_dtype)
        self.text_key_second = text_key_second
        self.text_key_third = text_key_third

    def process_single(self, sample, rank=None, context=False):
        """
        Generate images for the caption pair stored in ``sample``.

        :param sample: a single sample; it must contain the fields named by
            ``text_key_second`` and ``text_key_third``.
        :param rank: forwarded to ``get_model`` when fetching the pipeline.
        :param context: unused by this operator.
        :return: the same sample, with ``sample[self.image_key]`` replaced by
            the list of generated images.
        :raises ValueError: if ``text_key_second`` or ``text_key_third`` was
            not configured.
        """
        # Fail fast, before fetching the model or touching the sample.
        # Logging alone would let execution continue into `sample[None]`.
        missing_keys = [
            name for name in ('text_key_second', 'text_key_third')
            if getattr(self, name) is None
        ]
        if missing_keys:
            for name in missing_keys:
                logger.error(
                    'This OP (sdxl_prompt2prompt_mapper) requires '
                    'processing multiple fields, and you need to specify '
                    'valid `%s`', name)
            raise ValueError(
                'sdxl_prompt2prompt_mapper requires valid `text_key_second` '
                'and `text_key_third`; missing: ' + ', '.join(missing_keys))

        model = get_model(model_key=self.model_key,
                          rank=rank,
                          use_cuda=self.use_cuda())

        # Fixed seed: every call starts from the same generator state.
        seed = 0
        g_cpu = torch.Generator().manual_seed(seed)

        cross_attention_kwargs = {
            'edit_type': 'refine',
            'n_self_replace': 0.4,
            'n_cross_replace': {
                'default_': 1.0,
                'confetti': 0.8
            },
        }

        prompts = [sample[self.text_key_second], sample[self.text_key_third]]

        with torch.no_grad():
            output = model(prompts,
                           cross_attention_kwargs=cross_attention_kwargs,
                           guidance_scale=self.guidance_scale,
                           num_inference_steps=self.num_inference_steps,
                           generator=g_cpu)

        # The pipeline output is keyed by 'images' regardless of which sample
        # field this operator writes to, so do not index it with
        # self.image_key. Assign only after generation succeeded so a failed
        # call does not clear the sample's existing images.
        sample[self.image_key] = list(output['images'])

        return sample