Source code for data_juicer.ops.mapper.pair_preference_mapper

import re
from typing import Dict, Optional

from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = "pair_preference_mapper"


# TODO: Extend LLM-based OPs into API-based implementation.
[docs] @OPERATORS.register_module(OP_NAME) class PairPreferenceMapper(Mapper): """ Mapper to construct paired preference samples. """ # avoid leading whitespace DEFAULT_SYSTEM_PROMPT = ( "你的任务是根据参考信息修改问答对中的回答,在语言风格、事实性、人物身份、立场等任一方面与原回答相反。" "必须按照以下标记格式输出,不要输出其他多余内容。\n" "【回答】\n" "生成的新回答\n" "【原因】\n" "生成该回答的原因" ) DEFAULT_INPUT_TEMPLATE = ( "【参考信息】\n" "{reference}\n" "\n" "以下是原始问答对:\n" "【问题】\n" "{query}\n" "【回答】\n" "{response}" ) DEFAULT_OUTPUT_PATTERN = r".*?【回答】\s*(.*?)\s*【原因】\s*(.*)"
[docs] def __init__( self, api_model: str = "gpt-4o", *, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt: Optional[str] = None, input_template: Optional[str] = None, output_pattern: Optional[str] = None, rejected_key: str = "rejected_response", reason_key: str = "reason", try_num: PositiveInt = 3, model_params: Dict = {}, sampling_params: Dict = {}, **kwargs, ): """ Initialization method. :param api_model: API model name. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. :param system_prompt: System prompt for guiding the generation task. :param input_template: Template for building the model input. It must contain placeholders '{query}' and '{response}', and can optionally include '{reference}'. :param output_pattern: Regular expression for parsing model output. :param rejected_key: The field name in the sample to store the generated rejected response. Defaults to 'rejected_response'. :param reason_key: The field name in the sample to store the reason for generating the response. Defaults to 'reason'. :param try_num: The number of retries for the API call in case of response parsing failure. Defaults to 3. :param model_params: Parameters for initializing the API model. :param sampling_params: Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} :param kwargs: Extra keyword arguments. """ super().__init__(**kwargs) self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN self.rejected_key = rejected_key self.reason_key = reason_key self.model_key = prepare_model( model_type="api", model=api_model, endpoint=api_endpoint, response_path=response_path, **model_params ) self.try_num = try_num self.sampling_params = sampling_params
[docs] def build_input(self, sample): mapping = { "query": sample[self.query_key], "response": sample[self.response_key], "reference": sample.get(self.text_key, ""), } return self.input_template.format_map(mapping)
[docs] def parse_output(self, raw_output): logger.debug(raw_output) match = re.match(self.output_pattern, raw_output, re.DOTALL) if match: return match.group(1).strip(), match.group(2).strip() else: return ("", "")
[docs] def process_single(self, sample, rank=None): client = get_model(self.model_key, rank=rank) messages = [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": self.build_input(sample)}, ] parsed_rejected, parsed_reason = "", "" for _ in range(self.try_num): try: output = client(messages, **self.sampling_params) parsed_rejected, parsed_reason = self.parse_output(output) if parsed_rejected and parsed_reason: break except Exception as e: logger.warning(f"Exception: {e}") sample[self.rejected_key] = parsed_rejected sample[self.reason_key] = parsed_reason return sample