Source code for data_juicer.ops.mapper.extract_qa_mapper

import json
import re
from typing import Dict, Optional

from loguru import logger

from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

torch = LazyLoader('torch', 'torch')
vllm = LazyLoader('vllm', 'vllm')

OP_NAME = 'extract_qa_mapper'


# TODO: Extend LLM-based OPs into API-based implementation.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class ExtractQAMapper(Mapper):
    """
    Mapper to extract question and answer pairs from text samples.

    Recommended model list: [
        'alibaba-pai/pai-llama3-8b-doc2qa',
        'alibaba-pai/pai-baichuan2-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-4b-doc2qa',
        'alibaba-pai/pai-qwen1_5-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-1b8-doc2qa',
        'alibaba-pai/pai-qwen1_5-0b5-doc2qa'
    ]
    These recommended models are all trained with Chinese data and are
    suitable for Chinese text.
    """

    _accelerator = 'cuda'

    def __init__(self,
                 hf_model: str = 'alibaba-pai/pai-qwen1_5-7b-doc2qa',
                 trust_remote_code: bool = False,
                 pattern: Optional[str] = None,
                 qa_format: str = 'chatml',
                 enable_vllm: bool = True,
                 tensor_parallel_size: Optional[int] = None,
                 max_model_len: Optional[int] = None,
                 max_num_seqs: int = 256,
                 sampling_params: Dict = {},
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_model: Hugging Face model id.
        :param trust_remote_code: whether to trust remote code of the model;
            passed to transformers.
        :param pattern: regular expression pattern to search for within text.
        :param qa_format: output format of question and answer pairs.
        :param enable_vllm: whether to use vllm for inference acceleration.
        :param tensor_parallel_size: it is only valid when enable_vllm is
            True. The number of GPUs to use for distributed execution with
            tensor parallelism.
        :param max_model_len: it is only valid when enable_vllm is True.
            Model context length. If unspecified, it will be automatically
            derived from the model config.
        :param max_num_seqs: it is only valid when enable_vllm is True.
            Maximum number of sequences to be processed in a single iteration.
        :param sampling_params: sampling parameters for text generation,
            e.g. {'temperature': 0.9, 'top_p': 0.95}.
        :param args: extra args
        :param kwargs: extra args

        The default data format parsed by this interface is as follows:

        Model Input:
            蒙古国的首都是乌兰巴托(Ulaanbaatar)
            冰岛的首都是雷克雅未克(Reykjavik)

        Model Output:
            蒙古国的首都是乌兰巴托(Ulaanbaatar)
            冰岛的首都是雷克雅未克(Reykjavik)
            Human: 请问蒙古国的首都是哪里?
            Assistant: 你好,根据提供的信息,蒙古国的首都是乌兰巴托(Ulaanbaatar)。
            Human: 冰岛的首都是哪里呢?
            Assistant: 冰岛的首都是雷克雅未克(Reykjavik)。
            ...
        """
        super().__init__(*args, **kwargs)
        self.num_proc = 1

        if pattern is None:
            self.pattern = r'Human: (.*?)\nAssistant: (.*?)(?=\nHuman|$)'
        else:
            self.pattern = pattern

        self.qa_format = qa_format
        self.enable_vllm = enable_vllm

        if enable_vllm:
            assert torch.cuda.device_count() >= 1, 'must be executed in CUDA'
            if not tensor_parallel_size:
                tensor_parallel_size = torch.cuda.device_count()
                logger.info(f'Set tensor_parallel_size to '
                            f'{tensor_parallel_size} for vllm.')
            self.model_key = prepare_model(
                model_type='vllm',
                pretrained_model_name_or_path=hf_model,
                trust_remote_code=trust_remote_code,
                tensor_parallel_size=tensor_parallel_size,
                max_model_len=max_model_len,
                max_num_seqs=max_num_seqs)
            self.sampling_params = vllm.SamplingParams(**sampling_params)
        else:
            self.model_key = prepare_model(
                model_type='huggingface',
                pretrained_model_name_or_path=hf_model,
                trust_remote_code=trust_remote_code)
            self.sampling_params = sampling_params

    def _extract_qa(self, output):
        """Extract question and answer pairs from the model output."""
        qa_list = []

        pat = re.compile(self.pattern, re.DOTALL)
        qa_pairs = pat.findall(output)

        for user, assistant in qa_pairs:
            qa_list.append((user.strip(), assistant.strip()))

        return qa_list

    def process_single(self, sample, rank=None):
        model, processor = get_model(self.model_key, rank, self.use_cuda())

        if self.enable_vllm:
            response = model.generate([sample[self.text_key]],
                                      self.sampling_params)
            output = response[0].outputs[0].text
        else:
            inputs = processor(sample[self.text_key],
                               return_tensors='pt').to(model.device)
            response = model.generate(**inputs, **self.sampling_params)
            output = processor.decode(response.cpu()[0],
                                      skip_special_tokens=True)

        qa_list = self._extract_qa(output)

        if not qa_list:
            logger.info(
                'No question and answer data was extracted from this sample!')

        dialogue_data = []
        if self.qa_format == 'chatml':
            for qa in qa_list:
                dialogue_data.append({
                    'messages': [{
                        'role': 'user',
                        'content': qa[0]
                    }, {
                        'role': 'assistant',
                        'content': qa[1]
                    }]
                })
        else:
            raise ValueError(f'Unsupported qa_format: {self.qa_format}!')

        sample[self.text_key] = json.dumps(dialogue_data, ensure_ascii=False)
        return sample
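

# Minimal usage sketch (not part of the original module). It assumes
# data_juicer is installed, at least one CUDA device is available for vllm,
# and that samples store their content under the default 'text' key; the
# parameter values below are illustrative only.
if __name__ == '__main__':
    op = ExtractQAMapper(
        hf_model='alibaba-pai/pai-qwen1_5-7b-doc2qa',
        enable_vllm=True,
        sampling_params={'temperature': 0.9, 'top_p': 0.95})

    demo_sample = {
        'text': '蒙古国的首都是乌兰巴托(Ulaanbaatar)\n'
                '冰岛的首都是雷克雅未克(Reykjavik)'
    }
    result = op.process_single(demo_sample)

    # The text field is replaced with a JSON string of chatml-style dialogues,
    # e.g. [{"messages": [{"role": "user", ...}, {"role": "assistant", ...}]}].
    print(result['text'])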