@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class SDXLPrompt2PromptMapper(Mapper):
    """Generate pairs of similar images by the SDXL model."""

    _accelerator = "cuda"
    def __init__(
        self,
        hf_diffusion: str = "stabilityai/stable-diffusion-xl-base-1.0",
        trust_remote_code=False,
        torch_dtype: str = "fp32",
        num_inference_steps: float = 50,
        guidance_scale: float = 7.5,
        text_key=None,
        text_key_second=None,
        output_dir=DATA_JUICER_ASSETS_CACHE,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param hf_diffusion: diffusion model name on huggingface to generate
            the image.
        :param trust_remote_code: whether to trust the remote code of the
            model from HuggingFace.
        :param torch_dtype: the floating point type used to load the diffusion
            model.
        :param num_inference_steps: the larger the value, the better the image
            generation quality; however, this also increases the time required
            for generation.
        :param guidance_scale: a higher guidance scale value encourages the
            model to generate images closely linked to the text prompt at the
            expense of lower image quality. Guidance scale is enabled when
            `guidance_scale > 1`.
        :param text_key: the key name used to store the first caption in the
            caption pair.
        :param text_key_second: the key name used to store the second caption
            in the caption pair.
        :param output_dir: the storage location of the generated images.
        """
        kwargs.setdefault("mem_required", "38GB")
        super().__init__(*args, **kwargs)
        self._init_parameters = self.remove_extra_parameters(locals())
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.hf_diffusion = hf_diffusion
        self.torch_dtype = torch_dtype
        self.model_key = prepare_model(
            model_type="sdxl-prompt-to-prompt",
            pretrained_model_name_or_path=hf_diffusion,
            pipe_func=p2p_pipeline.Prompt2PromptPipeline,
            torch_dtype=torch_dtype,
        )
        self.text_key_second = text_key_second
        self.output_dir = output_dir
        if text_key is not None:
            self.text_key = text_key
    def process_single(self, sample, rank=None, context=False):
        if self.text_key_second is None:
            logger.error(
                "This OP (sdxl_prompt2prompt_mapper) requires processing "
                "multiple fields, and you need to specify a valid "
                "`text_key_second`"
            )

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)

        # components for a unique output file name: timestamp + random suffix
        random_num = str(random.randint(0, 9999))
        t1 = time.localtime()
        t2 = time.strftime("%Y-%m-%d-%H-%M-%S", t1)

        model = get_model(model_key=self.model_key, rank=rank, use_cuda=self.use_cuda())

        seed = random.randint(0, 9999)
        g_cpu = torch.Generator().manual_seed(seed)
        # Prompt-to-Prompt "refine" editing: the two prompts share attention
        # maps for part of the denoising steps, so the generated images stay
        # similar except where the captions differ.
        cross_attention_kwargs = {
            "edit_type": "refine",
            "n_self_replace": 0.4,
            "n_cross_replace": {"default_": 1.0, "confetti": 0.8},
        }

        with torch.no_grad():
            prompts = [sample[self.text_key], sample[self.text_key_second]]
            image = model(
                prompts,
                cross_attention_kwargs=cross_attention_kwargs,
                guidance_scale=self.guidance_scale,
                num_inference_steps=self.num_inference_steps,
                generator=g_cpu,
            )

        # save the generated image pair and record the paths in the sample
        new_output_dir = transfer_data_dir(self.output_dir, OP_NAME)
        for idx, img in enumerate(image["images"]):
            img_id = str(idx + 1)
            image_name = "image_pair_" + t2 + "_" + random_num + "_" + img_id + ".jpg"
            abs_image_path = os.path.join(new_output_dir, image_name)
            img.save(abs_image_path)
            sample["image_path" + img_id] = abs_image_path

        return sample
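
A minimal usage sketch follows, showing how this operator might be run on a single sample. The import path and the sample keys ("text", "text_second") are assumptions made for illustration; in practice the operator is normally configured through a Data-Juicer recipe rather than instantiated by hand.

from data_juicer.ops.mapper import SDXLPrompt2PromptMapper  # assumed import path

op = SDXLPrompt2PromptMapper(
    hf_diffusion="stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype="fp16",             # reduced precision may be used if supported
    num_inference_steps=50,
    guidance_scale=7.5,
    text_key="text",                # first caption of the pair (illustrative key)
    text_key_second="text_second",  # second caption of the pair (illustrative key)
    output_dir="./outputs/sdxl_p2p",
)

sample = {
    "text": "a photo of a corgi running on the beach",
    "text_second": "a photo of a corgi running on the beach with confetti",
}
sample = op.process_single(sample)
# the saved image locations are written back into the sample, e.g.
# sample["image_path1"] and sample["image_path2"]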