Source code for data_juicer.ops.mapper.generate_qa_from_text_mapper
import re
from typing import Dict, Optional

from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import (get_model, prepare_model,
                                           update_sampling_params)

torch = LazyLoader('torch')
vllm = LazyLoader('vllm')

OP_NAME = 'generate_qa_from_text_mapper'
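
# NOTE: torch and vllm above are wrapped in LazyLoader so that importing this
# module (e.g. during operator registration) does not fail when they are not
# installed; the actual import is deferred to first attribute access.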

# TODO: Extend LLM-based OPs into API-based implementation.
@OPERATORS.register_module(OP_NAME)
class GenerateQAFromTextMapper(Mapper):
    """
    Mapper to generate question and answer pairs from text.
    Recommended model list: [
        'alibaba-pai/pai-llama3-8b-doc2qa',
        'alibaba-pai/pai-baichuan2-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-4b-doc2qa',
        'alibaba-pai/pai-qwen1_5-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-1b8-doc2qa',
        'alibaba-pai/pai-qwen1_5-0b5-doc2qa'
    ]
    These recommended models are all trained with Chinese data and are
    suitable for Chinese.
    """

    _accelerator = 'cuda'
    _batched_op = True
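
    # _accelerator marks this op as GPU-backed; _batched_op means
    # process_batched receives column-oriented batches ({key: list of values})
    # rather than single samples.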
    def __init__(self,
                 hf_model: str = 'alibaba-pai/pai-qwen1_5-7b-doc2qa',
                 max_num: Optional[PositiveInt] = None,
                 *,
                 output_pattern: Optional[str] = None,
                 enable_vllm: bool = False,
                 model_params: Optional[Dict] = None,
                 sampling_params: Optional[Dict] = None,
                 **kwargs):
        """
        Initialization method.

        :param hf_model: Huggingface model ID.
        :param max_num: The maximum number of returned QA pairs for each
            text. No limit if it is None.
        :param output_pattern: Regular expression pattern to extract
            questions and answers from the model response.
        :param enable_vllm: Whether to use vllm for inference acceleration.
        :param model_params: Parameters for initializing the model.
        :param sampling_params: Sampling parameters for text generation,
            e.g. {'temperature': 0.9, 'top_p': 0.95}.
        :param kwargs: Extra keyword arguments.

        The default data format parsed by this interface is as follows:

        Model Input:
            蒙古国的首都是乌兰巴托(Ulaanbaatar)
            冰岛的首都是雷克雅未克(Reykjavik)

        Model Output:
            蒙古国的首都是乌兰巴托(Ulaanbaatar)
            冰岛的首都是雷克雅未克(Reykjavik)
            Human: 请问蒙古国的首都是哪里?
            Assistant: 你好,根据提供的信息,蒙古国的首都是乌兰巴托(Ulaanbaatar)。
            Human: 冰岛的首都是哪里呢?
            Assistant: 冰岛的首都是雷克雅未克(Reykjavik)。
            ...
        """
        super().__init__(**kwargs)

        self.max_num = max_num

        if output_pattern is None:
            self.output_pattern = r'Human:(.*?)Assistant:(.*?)(?=Human|$)'  # noqa: E501
        else:
            self.output_pattern = output_pattern

        self.enable_vllm = enable_vllm
        model_params = model_params or {}
        sampling_params = sampling_params or {}

        sampling_params = update_sampling_params(sampling_params, hf_model,
                                                 self.enable_vllm)

        if enable_vllm:
            assert torch.cuda.device_count() >= 1, 'must be executed in CUDA'
            # cannot initialize vllm replicas on different GPUs
            self.num_proc = 1
            if model_params.get('tensor_parallel_size') is None:
                tensor_parallel_size = torch.cuda.device_count()
                logger.info(f'Set tensor_parallel_size to '
                            f'{tensor_parallel_size} for vllm.')
                model_params['tensor_parallel_size'] = tensor_parallel_size
            self.model_key = prepare_model(
                model_type='vllm',
                pretrained_model_name_or_path=hf_model,
                **model_params)
            self.sampling_params = vllm.SamplingParams(**sampling_params)
        else:
            self.model_key = prepare_model(
                model_type='huggingface',
                pretrained_model_name_or_path=hf_model,
                return_pipe=True,
                **model_params)
            self.sampling_params = sampling_params
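
    # NOTE: parse_output is called by process_batched below but is missing
    # from this listing. The following is a minimal reconstruction sketch,
    # assuming the default two-capture-group output_pattern defined above;
    # the upstream implementation may differ in details.
    def parse_output(self, raw_output):
        logger.debug(raw_output)
        qa_list = []
        # Each match is a (question, answer) tuple captured from the
        # 'Human: ... Assistant: ...' dialogue format.
        matches = re.findall(self.output_pattern, raw_output, re.DOTALL)
        for match in matches:
            user, assistant = match
            qa_list.append((user.strip(), assistant.strip()))
        return qa_list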

    def process_batched(self, samples, rank=None):
        model, _ = get_model(self.model_key, rank, self.use_cuda())

        input_keys = samples.keys()
        num_samples = len(samples[next(iter(input_keys))])
        output_keys = input_keys | {self.query_key, self.response_key}
        output_samples = {key: [] for key in output_keys}
        for i in range(num_samples):
            messages = [{'role': 'user',
                         'content': samples[self.text_key][i]}]

            if self.enable_vllm:
                response = model.chat(messages, self.sampling_params)
                output = response[0].outputs[0].text
            else:
                # model is pipe
                response = model(messages,
                                 return_full_text=False,
                                 **self.sampling_params)
                output = response[0]['generated_text']

            qa_list = self.parse_output(output)
            if self.max_num is not None:
                qa_list = qa_list[:self.max_num]

            if len(qa_list) > 0:
                for q, a in qa_list:
                    for input_k in input_keys:
                        output_samples[input_k].append(samples[input_k][i])
                    output_samples[self.query_key].append(q)
                    output_samples[self.response_key].append(a)
            else:
                logger.warning(
                    'No question and answer was extracted from current sample!')

        return output_samples
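
A minimal usage sketch (not part of the module source), assuming the Mapper
defaults text_key='text', query_key='query', and response_key='response'; the
model choice and max_num value are illustrative:

from data_juicer.ops.mapper.generate_qa_from_text_mapper import \
    GenerateQAFromTextMapper

op = GenerateQAFromTextMapper(
    hf_model='alibaba-pai/pai-qwen1_5-7b-doc2qa',
    max_num=3,
)
# Batched samples are column-oriented: {key: list of values}.
samples = {'text': ['蒙古国的首都是乌兰巴托(Ulaanbaatar)\n'
                    '冰岛的首都是雷克雅未克(Reykjavik)']}
result = op.process_batched(samples)
# result gains 'query' and 'response' columns, with one output row per QA
# pair extracted from each input text.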