# Source code for data_juicer.ops.mapper.extract_nickname_mapper
import re
from typing import Dict, Optional
from loguru import logger
from pydantic import PositiveInt
from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.model_utils import get_model, prepare_model
OP_NAME = 'extract_nickname_mapper'
# TODO: LLM-based inference.
[docs]
@OPERATORS.register_module(OP_NAME)
class ExtractNicknameMapper(Mapper):
    """
    Extract nickname relationships between characters in the text.

    The operator sends the sample text to an API model together with a
    system prompt asking for "speaker -> addressee -> nickname" triples in
    a fixed markdown layout, parses the reply with a regular expression
    and stores the resulting relations under ``nickname_key``.
    """

    # System prompt (Chinese): asks the model to list how characters
    # address each other and prescribes the markdown layout that
    # DEFAULT_OUTPUT_PATTERN parses.
    DEFAULT_SYSTEM_PROMPT = ('给定你一段文本,你的任务是将人物之间的称呼方式(昵称)提取出来。\n'
                             '要求:\n'
                             '- 需要给出说话人对被称呼人的称呼,不要搞反了。\n'
                             '- 相同的说话人和被称呼人最多给出一个最常用的称呼。\n'
                             '- 请不要输出互相没有昵称的称呼方式。\n'
                             '- 输出格式如下:\n'
                             '```\n'
                             '### 称呼方式1\n'
                             '- **说话人**:...\n'
                             '- **被称呼人**:...\n'
                             '- **...对...的昵称**:...\n'
                             '### 称呼方式2\n'
                             '- **说话人**:...\n'
                             '- **被称呼人**:...\n'
                             '- **...对...的昵称**:...\n'
                             '### 称呼方式3\n'
                             '- **说话人**:...\n'
                             '- **被称呼人**:...\n'
                             '- **...对...的昵称**:...\n'
                             '...\n'
                             '```\n')

    # User message template; ``{text}`` is replaced by the sample text.
    DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n'

    # Compiled with re.VERBOSE | re.DOTALL in ``parse_output``, so the
    # whitespace/indentation and the trailing ``# ...`` comment inside the
    # literal are ignored by the regex engine.
    # Groups: (1) entry index, (2) speaker, (3) addressee,
    # (4) speaker as restated in the nickname line, (5) addressee as
    # restated in the nickname line, (6) nickname.
    DEFAULT_OUTPUT_PATTERN = r"""
        \#\#\#\s*称呼方式(\d+)\s*
        -\s*\*\*说话人\*\*\s*:\s*(.*?)\s*
        -\s*\*\*被称呼人\*\*\s*:\s*(.*?)\s*
        -\s*\*\*(.*?)对(.*?)的昵称\*\*\s*:\s*(.*?)(?=\#\#\#|\Z) # for double check
    """

    def __init__(self,
                 api_model: str = 'gpt-4o',
                 *,
                 nickname_key: str = Fields.nickname,
                 api_endpoint: Optional[str] = None,
                 response_path: Optional[str] = None,
                 system_prompt: Optional[str] = None,
                 input_template: Optional[str] = None,
                 output_pattern: Optional[str] = None,
                 try_num: PositiveInt = 3,
                 drop_text: bool = False,
                 model_params: Dict = {},
                 sampling_params: Dict = {},
                 **kwargs):
        """
        Initialization method.

        :param api_model: API model name.
        :param nickname_key: The field name to store the nickname
            relationship. It's "__dj__nickname__" in default.
        :param api_endpoint: URL endpoint for the API.
        :param response_path: Path to extract content from the API response.
            Defaults to 'choices.0.message.content'.
        :param system_prompt: System prompt for the task. Falls back to
            ``DEFAULT_SYSTEM_PROMPT`` when empty or None.
        :param input_template: Template for building the model input. Falls
            back to ``DEFAULT_INPUT_TEMPLATE`` when empty or None.
        :param output_pattern: Regular expression for parsing model output.
            Falls back to ``DEFAULT_OUTPUT_PATTERN`` when empty or None. It
            is compiled with ``re.VERBOSE | re.DOTALL`` and must yield six
            groups per match (see ``parse_output``).
        :param try_num: The number of retry attempts when there is an API
            call error or output parsing error.
        :param drop_text: If drop the text in the output.
        :param model_params: Parameters for initializing the API model.
        :param sampling_params: Extra parameters passed to the API call.
            e.g {'temperature': 0.9, 'top_p': 0.95}
        :param kwargs: Extra keyword arguments.
        """
        super().__init__(**kwargs)

        self.nickname_key = nickname_key

        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
        self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
        self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN

        # The signature keeps its historical ``{}`` defaults, but the
        # instance must not alias that shared default object (or the
        # caller's dict): store a private shallow copy instead.
        self.sampling_params = dict(sampling_params or {})
        self.model_key = prepare_model(model_type='api',
                                       model=api_model,
                                       endpoint=api_endpoint,
                                       response_path=response_path,
                                       **(model_params or {}))

        self.try_num = try_num
        self.drop_text = drop_text

    def parse_output(self, raw_output):
        """
        Parse the raw model reply into a list of nickname relations.

        An entry is kept only if the speaker/addressee named in its
        nickname line agree with the ones declared in the two lines above
        it, the nickname differs from the addressee's name, and none of the
        three values is empty. Duplicate triples are removed while keeping
        the order of first appearance, so the result is deterministic for
        a given reply.

        :param raw_output: Text returned by the model.
        :return: A list of dicts keyed by ``Fields.source_entity``
            (speaker), ``Fields.target_entity`` (addressee),
            ``Fields.relation_description`` (nickname),
            ``Fields.relation_keywords`` (always ``['nickname']``) and
            ``Fields.relation_strength`` (always ``None``).
        """
        pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)

        triples = []
        for match in pattern.findall(raw_output):
            _, speaker, addressee, speaker_check, addressee_check, nickname = (
                match)
            speaker = speaker.strip()
            addressee = addressee.strip()
            nickname = nickname.strip()

            # Double check: the names restated in the nickname line must
            # agree with the declared speaker/addressee.
            if (speaker != speaker_check.strip()
                    or addressee != addressee_check.strip()):
                continue
            # The model returned the plain name, not a nickname.
            if addressee == nickname:
                continue
            if speaker and addressee and nickname:
                triples.append((speaker, addressee, nickname))

        # dict.fromkeys de-duplicates while preserving insertion order;
        # a set would make the output order vary between processes because
        # string hashing is randomized.
        return [{
            Fields.source_entity: speaker,
            Fields.target_entity: addressee,
            Fields.relation_description: nickname,
            Fields.relation_keywords: ['nickname'],
            Fields.relation_strength: None
        } for speaker, addressee, nickname in dict.fromkeys(triples)]

    def process_single(self, sample, rank=None):
        """
        Query the model for one sample and attach the parsed relations.

        The API is called up to ``try_num`` times; the loop stops at the
        first reply that yields at least one relation. Any exception raised
        by the call or by parsing is logged and counts as a failed attempt.
        If every attempt fails, an empty list is stored.

        :param sample: The sample dict; its text is read from
            ``self.text_key``.
        :param rank: Passed through to ``get_model``.
        :return: The same sample with ``self.nickname_key`` set and, when
            ``drop_text`` is true, with the text field removed.
        """
        client = get_model(self.model_key, rank=rank)

        input_prompt = self.input_template.format(text=sample[self.text_key])
        messages = [{
            'role': 'system',
            'content': self.system_prompt
        }, {
            'role': 'user',
            'content': input_prompt
        }]

        nickname_relations = []
        for _ in range(self.try_num):
            # Deliberately broad best-effort: API errors and parsing errors
            # are both retried rather than failing the whole sample.
            try:
                output = client(messages, **self.sampling_params)
                nickname_relations = self.parse_output(output)
                if nickname_relations:
                    break
            except Exception as e:
                logger.warning(f'Exception: {e}')

        sample[self.nickname_key] = nickname_relations
        if self.drop_text:
            sample.pop(self.text_key)

        return sample