Source code for data_juicer.ops.mapper.generate_qa_from_examples_mapper

import json
import random
import re
from typing import Dict, Optional

from loguru import logger
from pydantic import PositiveInt

from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, UNFORKABLE, Mapper

torch = LazyLoader('torch', 'torch')
vllm = LazyLoader('vllm', 'vllm')
rouge = LazyLoader('rouge', 'rouge')

OP_NAME = 'generate_qa_from_examples_mapper'


# TODO: Extend LLM-based OPs into API-based implementation.
[docs]@UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class GenerateQAFromExamplesMapper(Mapper): """ Mapper to generate question and answer pairs from examples. You should configure an empty dataset in your yaml config file: ``` generated_dataset_config: type: 'EmptyFormatter' # use `RayEmptyFormatter` when enable ray length: ${The number of generated samples} feature_keys: ${text key} ``` The number of samples generated is determined by the length of the empty dataset. """ DEFAULT_SYSTEM_PROMPT = ( '请你仔细观察多个示例数据的输入和输出,按照你的理解,总结出相应规矩,然后写出一个新的【问题】和【回答】。' '注意,新生成的【问题】和【回答】需要满足如下要求:\n' '1. 生成的【问题】和【回答】不能与输入的【问题】和【回答】一致,但是需要保持格式相同。\n' '2. 生成的【问题】不一定要局限于输入【问题】的话题或领域,生成的【回答】需要正确回答生成的【问题】。\n' '3. 提供的【问题】和【回答】可能是多轮对话,生成的【问题】和【回答】也可以是多轮,但是需要保持格式相同。\n' '4. 生成的【问题】和【回答】必须成对出现,而且【问题】需要在【回答】之前。\n') DEFAULT_INPUT_TEMPLATE = '{}' DEFAULT_EXAMPLE_TEMPLATE = '\n如下是一条示例数据:\n{}' DEFAULT_QA_PAIR_TEMPLATE = '【问题】\n{}\n【回答】\n{}\n' DEFAULT_OUTPUT_PATTERN = r'【问题】(.*?)【回答】(.*?)(?=【问题】|$)' _accelerator = 'cuda'
[docs] def __init__(self, hf_model: str = 'Qwen/Qwen2.5-7B-Instruct', *, seed_file: str = '', example_num: PositiveInt = 3, similarity_threshold: float = 0.7, system_prompt: Optional[str] = None, input_template: Optional[str] = None, example_template: Optional[str] = None, qa_pair_template: Optional[str] = None, output_pattern: Optional[str] = None, enable_vllm: bool = False, model_params: Optional[Dict] = None, sampling_params: Optional[Dict] = None, **kwargs): """ Initialization method. :param hf_model: Hugginface model ID. :param seed_file: Path to the seed file in chatml format. :param example_num: The number of selected examples. Randomly select N examples from "seed_file" and put them into prompt as QA examples. :param similarity_threshold: The similarity score threshold between the generated samples and the seed examples. Range from 0 to 1. Samples with similarity score less than this threshold will be kept. :param system_prompt: System prompt for guiding the generation task. :param input_template: Template for building the input prompt. It must include one placeholder '{}', which will be replaced by `example_num` formatted examples defined by `example_template`. :param example_template: Template for formatting one QA example. It must include one placeholder '{}', which will be replaced by one formatted qa_pair. :param qa_pair_template: Template for formatting a single QA pair within each example. Must include two placeholders '{}' for the question and answer. :param output_pattern: Regular expression pattern to extract questions and answers from model response. :param enable_vllm: Whether to use vllm for inference acceleration. :param model_params: Parameters for initializing the model. :param sampling_params: Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} :param kwargs: Extra keyword arguments. """ super().__init__(**kwargs) if not seed_file: raise ValueError( 'Please provide `seed_file` in chatml format.' 'Example: data-juicer/demos/data/demo-dataset-chatml.jsonl') self.seed_file = seed_file self.example_num = example_num self.similarity_threshold = similarity_threshold self.similarity_type = 'rouge_l' self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE self.example_template = example_template or self.DEFAULT_EXAMPLE_TEMPLATE # noqa: E501 self.qa_pair_template = qa_pair_template or \ self.DEFAULT_QA_PAIR_TEMPLATE self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN self.enable_vllm = enable_vllm model_params = model_params or {} sampling_params = sampling_params or {} if enable_vllm: assert torch.cuda.device_count() >= 1, 'must be executed in CUDA' # cannot initialize vllm replicas on different GPUs self.num_proc = 1 if model_params.get('tensor_parallel_size') is None: tensor_parallel_size = torch.cuda.device_count() logger.info(f'Set tensor_parallel_size to \ {tensor_parallel_size} for vllm.') model_params['tensor_parallel_size'] = tensor_parallel_size self.model_key = prepare_model( model_type='vllm', pretrained_model_name_or_path=hf_model, **model_params) self.sampling_params = vllm.SamplingParams(**sampling_params) else: self.model_key = prepare_model( model_type='huggingface', pretrained_model_name_or_path=hf_model, return_pipe=True, **model_params) self.sampling_params = sampling_params self.seed_qa_samples = self._load_seed_qa_samples() if len(self.seed_qa_samples) == 0: raise ValueError('No QA data was parsed from the seed file!')
def _load_seed_qa_samples(self): """Load QA pairs from chatml format file.""" qa_samples = [] with open(self.seed_file, encoding='utf-8') as f: lines = f.readlines() for line in lines: line = line.strip() qa_pairs = self._parse_chatml_str(line) if len(qa_pairs) > 0: qa_samples.append(qa_pairs) return qa_samples def _sample_to_str(self, qa_sample): return '\n'.join(['\n'.join(qa_pair) for qa_pair in qa_sample]) + '\n' def _max_rouge_l_score(self, hypothesis, references): r = rouge.Rouge() max_score = 0.0 hyp_str = self._sample_to_str(hypothesis) for reference in references: ref_str = self._sample_to_str(reference) scores = r.get_scores(hyp_str, ref_str) rouge_l_score = scores[0]['rouge-l']['f'] if rouge_l_score > max_score: max_score = rouge_l_score return max_score def _parse_chatml_str(self, sample_str): user_input = None assistant_output = None qa_pairs = [] data = json.loads(sample_str) for message in data['messages']: role = message['role'] content = message['content'] if role == 'user': user_input = content elif role == 'assistant': assistant_output = content qa_pairs.append((user_input, assistant_output)) return qa_pairs
[docs] def build_input(self, qa_examples): def format_qa_pairs(qa_example): return ''.join([ self.qa_pair_template.format(q, a) for q, a in qa_example if q and a ]) formatted_examples = ''.join([ self.example_template.format(qa_pairs=format_qa_pairs(qa_example)) for qa_example in qa_examples ]) input_prompt = self.input_template.format(examples=formatted_examples) return input_prompt
[docs] def parse_output(self, raw_output): logger.debug(raw_output) output_qa_pairs = [] matches = re.findall(self.output_pattern, raw_output, re.DOTALL) for match in matches: question, answer = match output_qa_pairs.append((question.strip(), answer.strip())) return output_qa_pairs
[docs] def process_single(self, sample, rank=None): model, _ = get_model(self.model_key, rank, self.use_cuda()) random_qa_samples = random.sample(self.seed_qa_samples, self.example_num) input_prompt = self.build_input(random_qa_samples) messages = [{ 'role': 'system', 'content': self.system_prompt }, { 'role': 'user', 'content': input_prompt }] if self.enable_vllm: response = model.chat(messages, self.sampling_params) output = response[0].outputs[0].text else: # model is pipe response = model(messages, return_full_text=False, **self.sampling_params) output = response[0]['generated_text'] output_qa_pairs = self.parse_output(output) if len(output_qa_pairs) == 0: logger.warning('Parse model response error! ' 'No data generated for the current response!') sample.update({ self.query_key: '', self.response_key: '', self.history_key: self.empty_history() }) return sample if self.similarity_type == 'rouge_l': sim_score = self._max_rouge_l_score(output_qa_pairs, random_qa_samples) else: raise ValueError( f'Not support similarity type "{self.similarity_type}"!') if sim_score <= self.similarity_threshold: query, response = output_qa_pairs[-1] history = output_qa_pairs[:-1] if len(history) == 0: history = self.empty_history() else: query = response = '' history = self.empty_history() logger.info('Filter this generated sample due to similarity.') sample.update({ self.query_key: query, self.response_key: response, self.history_key: history }) return sample