import re
from typing import Dict, Optional
from loguru import logger
from pydantic import PositiveInt
from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.model_utils import get_model, prepare_model
OP_NAME = 'pair_preference_mapper'
# TODO: Extend LLM-based OPs into API-based implementation.
[docs]
@OPERATORS.register_module(OP_NAME)
class PairPreferenceMapper(Mapper):
"""
Mapper to construct paired preference samples.
"""
# avoid leading whitespace
DEFAULT_SYSTEM_PROMPT = (
'你的任务是根据参考信息修改问答对中的回答,在语言风格、事实性、人物身份、立场等任一方面与原回答相反。'
'必须按照以下标记格式输出,不要输出其他多余内容。\n'
'【回答】\n'
'生成的新回答\n'
'【原因】\n'
'生成该回答的原因')
DEFAULT_INPUT_TEMPLATE = ('【参考信息】\n'
'{reference}\n'
'\n'
'以下是原始问答对:\n'
'【问题】\n'
'{query}\n'
'【回答】\n'
'{response}')
DEFAULT_OUTPUT_PATTERN = r'.*?【回答】\s*(.*?)\s*【原因】\s*(.*)'
[docs]
def __init__(self,
api_model: str = 'gpt-4o',
*,
api_endpoint: Optional[str] = None,
response_path: Optional[str] = None,
system_prompt: Optional[str] = None,
input_template: Optional[str] = None,
output_pattern: Optional[str] = None,
rejected_key: str = 'rejected_response',
reason_key: str = 'reason',
try_num: PositiveInt = 3,
model_params: Dict = {},
sampling_params: Dict = {},
**kwargs):
"""
Initialization method.
:param api_model: API model name.
:param api_endpoint: URL endpoint for the API.
:param response_path: Path to extract content from the API response.
Defaults to 'choices.0.message.content'.
:param system_prompt: System prompt for guiding the generation task.
:param input_template: Template for building the model input. It must
contain placeholders '{query}' and '{reponse}', and can optionally
include '{reference}'.
:param output_pattern: Regular expression for parsing model output.
:param rejected_key: The field name in the sample to store the
generated rejected response. Defaults to 'rejected_response'.
:param reason_key: The field name in the sample to store the reason for
generating the response. Defaults to 'reason'.
:param try_num: The number of retries for the API call in case of
response parsing failure. Defaults to 3.
:param model_params: Parameters for initializing the API model.
:param sampling_params: Extra parameters passed to the API call.
e.g {'temperature': 0.9, 'top_p': 0.95}
:param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)
self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN
self.rejected_key = rejected_key
self.reason_key = reason_key
self.model_key = prepare_model(model_type='api',
model=api_model,
endpoint=api_endpoint,
response_path=response_path,
**model_params)
self.try_num = try_num
self.sampling_params = sampling_params
[docs]
def parse_output(self, raw_output):
logger.debug(raw_output)
match = re.match(self.output_pattern, raw_output, re.DOTALL)
if match:
return match.group(1).strip(), match.group(2).strip()
else:
return ('', '')
[docs]
def process_single(self, sample, rank=None):
client = get_model(self.model_key, rank=rank)
messages = [{
'role': 'system',
'content': self.system_prompt
}, {
'role': 'user',
'content': self.build_input(sample)
}]
parsed_rejected, parsed_reason = '', ''
for _ in range(self.try_num):
try:
output = client(messages, **self.sampling_params)
parsed_rejected, parsed_reason = self.parse_output(output)
if parsed_rejected and parsed_reason:
break
except Exception as e:
logger.warning(f'Exception: {e}')
sample[self.rejected_key] = parsed_rejected
sample[self.reason_key] = parsed_reason
return sample