@LOADED_IMAGES.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class MllmMapper(Mapper):
    """Mapper to use MLLMs for visual question answering tasks.

    Recommended model list: [
        llava-hf/llava-v1.6-vicuna-7b-hf,
        Qwen/Qwen2-VL-7B-Instruct,
    ]
    """

    # run inference on GPU when available
    _accelerator = 'cuda'

    def __init__(self,
                 hf_model: str = 'llava-hf/llava-v1.6-vicuna-7b-hf',
                 max_new_tokens=256,
                 temperature=0.2,
                 top_p=None,
                 num_beams=1,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_model: huggingface model id.
        :param max_new_tokens: the maximum number of new tokens
            generated by the model.
        :param temperature: used to control the randomness of
            generated text. The higher the temperature, the more
            random and creative the generated text will be.
        :param top_p: randomly select the next word from the group
            of words whose cumulative probability reaches p.
        :param num_beams: the larger the beam search size, the higher
            the quality of the generated text.
        :param args: extra args
        :param kwargs: extra args
        """
        # limit intra-op threads: each worker process should not
        # oversubscribe CPU cores
        torch.set_num_threads(1)
        kwargs.setdefault('mem_required', '32GB')
        kwargs.setdefault('num_proc', 1)
        super().__init__(*args, **kwargs)

        self.hf_model = hf_model
        # defer actual model loading; only register it and keep the key
        self.model_key = prepare_model(
            model_type='huggingface',
            pretrained_model_name_or_path=hf_model)

        # generation hyper-parameters forwarded to model.generate()
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams

    def process_single(self, sample=None, rank=None):
        """Answer the sample's text prompt for each of its images.

        Replaces ``sample[self.text_key]`` with a list of model
        outputs, one per unique image in ``sample[self.image_key]``.

        :param sample: a single sample dict holding the text prompt
            and a list of image paths/keys.
        :param rank: worker rank used to pick the model replica.
        :return: the (mutated) sample.
        """
        # there is no image in this sample
        if self.image_key not in sample or not sample[self.image_key]:
            return sample

        # load images; dict keyed by path avoids loading duplicates
        loaded_image_keys = sample[self.image_key]
        images = {}
        for loaded_image_key in loaded_image_keys:
            if loaded_image_key not in images:
                images[loaded_image_key] = load_image(loaded_image_key)

        model, processor = get_model(model_key=self.model_key,
                                     rank=rank,
                                     use_cuda=self.use_cuda())

        # single-turn chat with one text prompt and one image slot
        conversation = [
            {
                'role': 'user',
                'content': [
                    {'type': 'text', 'text': sample[self.text_key]},
                    {'type': 'image'},
                ],
            },
        ]
        prompt = processor.apply_chat_template(conversation,
                                               add_generation_prompt=True)

        # the original prompt text is replaced by the list of answers
        sample[self.text_key] = []
        for image_key in images:
            inputs = processor(images=images[image_key],
                               text=prompt,
                               return_tensors='pt').to(model.device)
            response = model.generate(**inputs,
                                      max_new_tokens=self.max_new_tokens,
                                      temperature=self.temperature,
                                      top_p=self.top_p,
                                      num_beams=self.num_beams)
            # NOTE(review): decodes the full sequence, which for many
            # MLLMs includes the prompt tokens — presumably intended;
            # verify against downstream consumers.
            output = processor.decode(response.cpu()[0],
                                      skip_special_tokens=True)
            sample[self.text_key].append(output)

        return sample