Source code for data_juicer.ops.mapper.generate_qa_from_text_mapper

import re
from typing import Dict, Optional

from loguru import logger

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

torch = LazyLoader('torch', 'torch')
vllm = LazyLoader('vllm', 'vllm')

OP_NAME = 'generate_qa_from_text_mapper'


# TODO: Extend LLM-based OPs into API-based implementation.
@OPERATORS.register_module(OP_NAME)
class GenerateQAFromTextMapper(Mapper):
    """
    Mapper to generate question and answer pairs from text.

    Recommended model list: [
        'alibaba-pai/pai-llama3-8b-doc2qa',
        'alibaba-pai/pai-baichuan2-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-4b-doc2qa',
        'alibaba-pai/pai-qwen1_5-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-1b8-doc2qa',
        'alibaba-pai/pai-qwen1_5-0b5-doc2qa'
    ]

    These recommended models are all trained on Chinese data and are suitable
    for Chinese text.
    """

    _accelerator = 'cuda'
    _batched_op = True
    def __init__(self,
                 hf_model: str = 'alibaba-pai/pai-qwen1_5-7b-doc2qa',
                 *,
                 output_pattern: Optional[str] = None,
                 enable_vllm: bool = False,
                 model_params: Optional[Dict] = None,
                 sampling_params: Optional[Dict] = None,
                 **kwargs):
        """
        Initialization method.

        :param hf_model: Hugging Face model ID.
        :param output_pattern: Regular expression pattern to extract
            questions and answers from the model response.
        :param enable_vllm: Whether to use vllm for inference acceleration.
        :param model_params: Parameters for initializing the model.
        :param sampling_params: Sampling parameters for text generation,
            e.g. {'temperature': 0.9, 'top_p': 0.95}.
        :param kwargs: Extra keyword arguments.

        The default data format parsed by this interface is as follows:

        Model Input:
            蒙古国的首都是乌兰巴托(Ulaanbaatar)
            冰岛的首都是雷克雅未克(Reykjavik)

        Model Output:
            蒙古国的首都是乌兰巴托(Ulaanbaatar)
            冰岛的首都是雷克雅未克(Reykjavik)
            Human: 请问蒙古国的首都是哪里?
            Assistant: 你好,根据提供的信息,蒙古国的首都是乌兰巴托(Ulaanbaatar)。
            Human: 冰岛的首都是哪里呢?
            Assistant: 冰岛的首都是雷克雅未克(Reykjavik)。
            ...
        """
        super().__init__(**kwargs)

        if output_pattern is None:
            self.output_pattern = r'Human:(.*?)Assistant:(.*?)(?=Human|$)'  # noqa: E501
        else:
            self.output_pattern = output_pattern

        self.enable_vllm = enable_vllm
        model_params = model_params or {}
        sampling_params = sampling_params or {}

        if enable_vllm:
            assert torch.cuda.device_count() >= 1, 'must be executed in CUDA'
            # vllm replicas cannot be initialized on different GPUs
            self.num_proc = 1
            if model_params.get('tensor_parallel_size') is None:
                tensor_parallel_size = torch.cuda.device_count()
                logger.info(f'Set tensor_parallel_size to '
                            f'{tensor_parallel_size} for vllm.')
                model_params['tensor_parallel_size'] = tensor_parallel_size
            self.model_key = prepare_model(
                model_type='vllm',
                pretrained_model_name_or_path=hf_model,
                **model_params)
            self.sampling_params = vllm.SamplingParams(**sampling_params)
        else:
            self.model_key = prepare_model(
                model_type='huggingface',
                pretrained_model_name_or_path=hf_model,
                return_pipe=True,
                **model_params)
            self.sampling_params = sampling_params
    def parse_output(self, raw_output):
        logger.debug(raw_output)
        qa_list = []
        matches = re.findall(self.output_pattern, raw_output, re.DOTALL)
        for match in matches:
            user, assistant = match
            qa_list.append((user.strip(), assistant.strip()))
        return qa_list
    def process_batched(self, samples, rank=None):
        model, _ = get_model(self.model_key, rank, self.use_cuda())

        input_keys = samples.keys()
        num_samples = len(samples[next(iter(input_keys))])
        output_keys = input_keys | {self.query_key, self.response_key}
        output_samples = {key: [] for key in output_keys}

        for i in range(num_samples):
            messages = [{'role': 'user', 'content': samples[self.text_key][i]}]

            if self.enable_vllm:
                response = model.chat(messages, self.sampling_params)
                output = response[0].outputs[0].text
            else:
                # model is a transformers pipeline
                response = model(messages,
                                 return_full_text=False,
                                 **self.sampling_params)
                output = response[0]['generated_text']

            qa_list = self.parse_output(output)

            if len(qa_list) > 0:
                for q, a in qa_list:
                    for input_k in input_keys:
                        output_samples[input_k].append(samples[input_k][i])
                    output_samples[self.query_key].append(q)
                    output_samples[self.response_key].append(a)
            else:
                logger.warning(
                    'No question and answer was extracted from the current '
                    'sample!')

        return output_samples
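
The whole flow hinges on the default output_pattern: each Human:/Assistant: turn in the model's continuation becomes one (question, answer) pair, and process_batched then emits one output sample per pair, copying the original input fields alongside the new query and response. The following is a minimal, self-contained sketch of that extraction step only; the raw_output string is a hypothetical English response written in the same format as the Chinese example in the __init__ docstring, not the output of any real model run.

import re

# Default pattern from GenerateQAFromTextMapper above.
output_pattern = r'Human:(.*?)Assistant:(.*?)(?=Human|$)'

# Hypothetical model continuation in the documented Human/Assistant format.
raw_output = (
    'The capital of Mongolia is Ulaanbaatar.\n'
    'Human: What is the capital of Mongolia?\n'
    'Assistant: The capital of Mongolia is Ulaanbaatar.\n'
    'Human: Which country is Ulaanbaatar the capital of?\n'
    'Assistant: Ulaanbaatar is the capital of Mongolia.\n'
)

# Mirrors parse_output: every Human/Assistant turn becomes one QA pair.
qa_list = [(q.strip(), a.strip())
           for q, a in re.findall(output_pattern, raw_output, re.DOTALL)]
print(qa_list)
# qa_list holds two (question, answer) tuples; process_batched would emit one
# output sample per tuple, storing the question under query_key and the answer
# under response_key while duplicating the remaining input fields.

A custom output_pattern passed at construction time just needs to expose the question and answer as the first and second capture groups, since parse_output unpacks each match into exactly that pair.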