@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class SDXLPrompt2PromptMapper(Mapper):
    """
    Generate pairs of similar images by the SDXL model.

    For every sample, the two captions stored under ``text_key`` and
    ``text_key_second`` are passed together to a prompt-to-prompt SDXL
    pipeline. Each generated image is saved to disk and its absolute path
    is written back into the sample under ``image_path<N>`` (1-based).
    """

    _accelerator = 'cuda'

    def __init__(self,
                 hf_diffusion: str = 'stabilityai/stable-diffusion-xl-base-1.0',
                 trust_remote_code=False,
                 torch_dtype: str = 'fp32',
                 num_inference_steps: int = 50,
                 guidance_scale: float = 7.5,
                 text_key=None,
                 text_key_second=None,
                 output_dir=DATA_JUICER_ASSETS_CACHE,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_diffusion: diffusion model name on huggingface to generate
            the image.
        :param trust_remote_code: accepted as part of the signature; it is
            not forwarded to `prepare_model` by this class.
        :param torch_dtype: the floating point type used to load the diffusion
            model.
        :param num_inference_steps: The larger the value, the better the
            image generation quality; however, this also increases the time
            required for generation.
        :param guidance_scale: A higher guidance scale value encourages the
            model to generate images closely linked to the text prompt at the
            expense of lower image quality. Guidance scale is enabled when
            `guidance_scale > 1`.
        :param text_key: the key name used to store the first caption
            in the caption pair. When None, the value set by the base class
            is kept.
        :param text_key_second: the key name used to store the second caption
            in the caption pair. Must be set before `process_single` is
            called.
        :param output_dir: the storage location of the generated images.
        :param args: extra positional args passed to the base class.
        :param kwargs: extra keyword args passed to the base class;
            `mem_required` defaults to '38GB' when not supplied.
        """
        # Only fills in the memory hint when the caller did not provide one.
        kwargs.setdefault('mem_required', '38GB')
        super().__init__(*args, **kwargs)
        # NOTE: keep this directly after super().__init__() and do not
        # introduce new locals above it -- it snapshots locals().
        self._init_parameters = self.remove_extra_parameters(locals())

        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.hf_diffusion = hf_diffusion
        self.torch_dtype = torch_dtype

        self.model_key = prepare_model(
            model_type='sdxl-prompt-to-prompt',
            pretrained_model_name_or_path=hf_diffusion,
            pipe_func=p2p_pipeline.Prompt2PromptPipeline,
            torch_dtype=torch_dtype)

        self.text_key_second = text_key_second
        self.output_dir = output_dir

        # Only override the text key when one was explicitly given.
        if text_key is not None:
            self.text_key = text_key

    def process_single(self, sample, rank=None, context=False):
        """
        Generate an image pair for one sample and record the saved paths.

        :param sample: the sample to process; it must contain the captions
            under `self.text_key` and `self.text_key_second`.
        :param rank: forwarded to `get_model` when fetching the pipeline.
        :param context: not used by this method.
        :return: the same sample, with one `image_path<N>` entry (N starting
            at 1) added per generated image, holding the saved file path.
        :raises ValueError: if `text_key_second` was not specified.
        """
        if self.text_key_second is None:
            message = ('This OP (sdxl_prompt2prompt_mapper) requires '
                       'processing multiple fields, and you need to specify '
                       'valid `text_key_second`')
            logger.error(message)
            # Fail before loading the model; continuing would only end in
            # an unrelated KeyError on `sample[None]`.
            raise ValueError(message)

        os.makedirs(self.output_dir, exist_ok=True)

        # The random suffix and the second-resolution timestamp together
        # form the file name stem shared by all images of this sample.
        random_num = str(random.randint(0, 9999))
        timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())

        model = get_model(model_key=self.model_key,
                          rank=rank,
                          use_cuda=self.use_cuda())

        seed = random.randint(0, 9999)
        generator = torch.Generator().manual_seed(seed)

        # NOTE(review): these prompt-to-prompt settings are hard-coded; the
        # 'confetti' entry looks like a word-specific weight -- confirm it is
        # intended for arbitrary captions.
        cross_attention_kwargs = {
            'edit_type': 'refine',
            'n_self_replace': 0.4,
            'n_cross_replace': {
                'default_': 1.0,
                'confetti': 0.8
            },
        }

        with torch.no_grad():
            prompts = [sample[self.text_key], sample[self.text_key_second]]
            image = model(prompts,
                          cross_attention_kwargs=cross_attention_kwargs,
                          guidance_scale=self.guidance_scale,
                          num_inference_steps=self.num_inference_steps,
                          generator=generator)

        new_output_dir = transfer_data_dir(self.output_dir, OP_NAME)
        # Images are written here rather than into `self.output_dir`, so
        # make sure it exists (no-op if the helper already created it).
        os.makedirs(new_output_dir, exist_ok=True)

        for idx, img in enumerate(image['images'], start=1):
            img_id = str(idx)
            image_name = f'image_pair_{timestamp}_{random_num}_{img_id}.jpg'
            abs_image_path = os.path.join(new_output_dir, image_name)
            img.save(abs_image_path)
            sample['image_path' + img_id] = abs_image_path

        return sample