Source code for data_juicer.ops.mapper.extract_support_text_mapper
from typing import Dict, Optional
from loguru import logger
from pydantic import PositiveInt
from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.common_utils import nested_access, nested_set
from data_juicer.utils.constant import Fields
from data_juicer.utils.model_utils import get_model, prepare_model
OP_NAME = 'extract_support_text_mapper'
# TODO: LLM-based inference.
[docs]
@OPERATORS.register_module(OP_NAME)
class ExtractSupportTextMapper(Mapper):
    """
    Extract the passage of a sample's text that supports a given summary.

    For each sample, the summary stored under ``summary_key`` and the text
    stored under ``self.text_key`` are sent to an API model, and the
    stripped model response is written to ``support_text_key``.
    """

    # Chinese-language system prompt: tells the model to act as a text
    # excerpt assistant that, given an article ("原文") and a short summary
    # of one part of it ("总结"), extracts the matching passage of the
    # article. It ends with a one-shot example.
    DEFAULT_SYSTEM_PROMPT = ('你将扮演一个文本摘录助手的角色。你的主要任务是基于给定'
                             '的文章(称为“原文”)以及对原文某个部分的简短描述或总结'
                             '(称为“总结”),准确地识别并提取出与该总结相对应的原文'
                             '片段。\n'
                             '要求:\n'
                             '- 你需要尽可能精确地匹配到最符合总结内容的那部分内容\n'
                             '- 如果存在多个可能的答案,请选择最贴近总结意思的那个\n'
                             '- 下面是一个例子帮助理解这一过程:\n'
                             '### 原文:\n'
                             '《红楼梦》是中国古典小说四大名著之一,由清代作家曹雪芹创'
                             '作。它讲述了贾宝玉、林黛玉等人的爱情故事及四大家族的兴衰'
                             '历程。书中通过复杂的人物关系展现了封建社会的各种矛盾冲突'
                             '。其中关于贾府内部斗争的部分尤其精彩,特别是王熙凤与尤二'
                             '姐之间的争斗,生动描绘了权力争夺下的女性形象。此外,《红'
                             '楼梦》还以其精美的诗词闻名,这些诗词不仅增添了文学色彩,'
                             '也深刻反映了人物的性格特点和命运走向。\n\n'
                             '### 总结:\n'
                             '描述了书中的两个女性角色之间围绕权力展开的竞争。\n\n'
                             '### 原文摘录:\n'
                             '其中关于贾府内部斗争的部分尤其精彩,特别是王熙凤与尤二姐'
                             '之间的争斗,生动描绘了权力争夺下的女性形象。')

    # User-message template; ``{text}`` and ``{summary}`` are filled in by
    # ``process_single`` via ``str.format``.
    DEFAULT_INPUT_TEMPLATE = ('### 原文:\n{text}\n\n'
                              '### 总结:\n{summary}\n\n'
                              '### 原文摘录:\n')

    def __init__(self,
                 api_model: str = 'gpt-4o',
                 *,
                 summary_key: str = Fields.event_description,
                 support_text_key: str = Fields.support_text,
                 api_endpoint: Optional[str] = None,
                 response_path: Optional[str] = None,
                 system_prompt: Optional[str] = None,
                 input_template: Optional[str] = None,
                 try_num: PositiveInt = 3,
                 drop_text: bool = False,
                 model_params: Optional[Dict] = None,
                 sampling_params: Optional[Dict] = None,
                 **kwargs):
        """
        Initialization method.

        :param api_model: API model name.
        :param summary_key: The field name to store the input summary.
            Support for nested keys such as "__dj__stats__.text_len".
            It's "__dj__event_description__" in default.
        :param support_text_key: The field name to store the output
            support text for the summary. It's "__dj__support_text__" in
            default.
        :param api_endpoint: URL endpoint for the API.
        :param response_path: Path to extract content from the API response.
            Defaults to 'choices.0.message.content'.
        :param system_prompt: System prompt for the task. Falls back to
            ``DEFAULT_SYSTEM_PROMPT`` when empty or None.
        :param input_template: Template for building the model input. Falls
            back to ``DEFAULT_INPUT_TEMPLATE`` when empty or None.
        :param try_num: The number of retry attempts when there is an API
            call error or output parsing error.
        :param drop_text: If drop the text in the output.
        :param model_params: Parameters for initializing the API model.
            None (the default) is treated as an empty dict.
        :param sampling_params: Extra parameters passed to the API call.
            e.g {'temperature': 0.9, 'top_p': 0.95}. None (the default)
            is treated as an empty dict.
        :param kwargs: Extra keyword arguments.
        """
        super().__init__(**kwargs)

        self.summary_key = summary_key
        self.support_text_key = support_text_key

        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
        self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE

        # Keep a per-instance copy so no two operators (and no caller) share
        # one mutable dict; the former ``{}`` defaults were shared by every
        # instance constructed without explicit parameters.
        self.sampling_params = dict(sampling_params) if sampling_params \
            else {}

        self.model_key = prepare_model(model_type='api',
                                       model=api_model,
                                       endpoint=api_endpoint,
                                       response_path=response_path,
                                       **(model_params or {}))

        self.try_num = try_num
        # NOTE(review): drop_text is stored here but is not read anywhere
        # else in this class -- confirm whether process_single is expected
        # to remove the text field when it is True.
        self.drop_text = drop_text

    def process_single(self, sample, rank=None):
        """
        Extract the support text for one sample.

        :param sample: The sample to process. It is read at
            ``self.summary_key`` and ``self.text_key`` (the latter is not
            set in this class; presumably provided by the base class).
        :param rank: Forwarded to ``get_model`` when fetching the client.
        :return: The sample with the support text stored under
            ``self.support_text_key``, or the unchanged sample when the
            summary is not a string.
        """
        client = get_model(self.model_key, rank=rank)

        summary = nested_access(sample, self.summary_key)
        if not isinstance(summary, str):
            logger.warning('Unvalid input summary!')
            return sample

        input_prompt = self.input_template.format(text=sample[self.text_key],
                                                  summary=summary)
        messages = [{
            'role': 'system',
            'content': self.system_prompt
        }, {
            'role': 'user',
            'content': input_prompt
        }]

        support_text = ''
        for _ in range(self.try_num):
            # Best effort: any failure of the call (or of handling its
            # result) is logged and the next attempt is made.
            try:
                response = client(messages, **self.sampling_params)
                support_text = response.strip()
                if support_text:
                    break
            except Exception as e:
                logger.warning(f'Exception: {e}')

        # Fall back to the summary itself when no attempt produced a
        # non-empty answer.
        if not support_text:
            support_text = summary

        sample = nested_set(sample, self.support_text_key, support_text)
        return sample