Source code for data_juicer.ops.mapper.extract_support_text_mapper
from typing import Dict, Optional
from loguru import logger
from pydantic import PositiveInt
from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model
OP_NAME = "extract_support_text_mapper"
# TODO: LLM-based inference.
[docs]
@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class ExtractSupportTextMapper(Mapper):
"""
Extract support sub text for a summary.
"""
DEFAULT_SYSTEM_PROMPT = (
"你将扮演一个文本摘录助手的角色。你的主要任务是基于给定"
"的文章(称为“原文”)以及对原文某个部分的简短描述或总结"
"(称为“总结”),准确地识别并提取出与该总结相对应的原文"
"片段。\n"
"要求:\n"
"- 你需要尽可能精确地匹配到最符合总结内容的那部分内容\n"
"- 如果存在多个可能的答案,请选择最贴近总结意思的那个\n"
"- 下面是一个例子帮助理解这一过程:\n"
"### 原文:\n"
"《红楼梦》是中国古典小说四大名著之一,由清代作家曹雪芹创"
"作。它讲述了贾宝玉、林黛玉等人的爱情故事及四大家族的兴衰"
"历程。书中通过复杂的人物关系展现了封建社会的各种矛盾冲突"
"。其中关于贾府内部斗争的部分尤其精彩,特别是王熙凤与尤二"
"姐之间的争斗,生动描绘了权力争夺下的女性形象。此外,《红"
"楼梦》还以其精美的诗词闻名,这些诗词不仅增添了文学色彩,"
"也深刻反映了人物的性格特点和命运走向。\n\n"
"### 总结:\n"
"描述了书中的两个女性角色之间围绕权力展开的竞争。\n\n"
"### 原文摘录:\n"
"其中关于贾府内部斗争的部分尤其精彩,特别是王熙凤与尤二姐"
"之间的争斗,生动描绘了权力争夺下的女性形象。"
)
DEFAULT_INPUT_TEMPLATE = "### 原文:\n{text}\n\n" "### 总结:\n{summary}\n\n" "### 原文摘录:\n"
[docs]
def __init__(
self,
api_model: str = "gpt-4o",
*,
summary_key: str = MetaKeys.event_description,
support_text_key: str = MetaKeys.support_text,
api_endpoint: Optional[str] = None,
response_path: Optional[str] = None,
system_prompt: Optional[str] = None,
input_template: Optional[str] = None,
try_num: PositiveInt = 3,
drop_text: bool = False,
model_params: Dict = {},
sampling_params: Dict = {},
**kwargs,
):
"""
Initialization method.
:param api_model: API model name.
:param summary_key: The key name to store the input summary in the
meta field. It's "event_description" in default.
:param support_text_key: The key name to store the output
support text for the summary in the meta field. It's
"support_text" in default.
:param api_endpoint: URL endpoint for the API.
:param response_path: Path to extract content from the API response.
Defaults to 'choices.0.message.content'.
:param system_prompt: System prompt for the task.
:param input_template: Template for building the model input.
:param try_num: The number of retry attempts when there is an API
call error or output parsing error.
:param drop_text: If drop the text in the output.
:param model_params: Parameters for initializing the API model.
:param sampling_params: Extra parameters passed to the API call.
e.g {'temperature': 0.9, 'top_p': 0.95}
:param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)
self.summary_key = summary_key
self.support_text_key = support_text_key
self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
self.sampling_params = sampling_params
self.model_key = prepare_model(
model_type="api", model=api_model, endpoint=api_endpoint, response_path=response_path, **model_params
)
self.try_num = try_num
self.drop_text = drop_text
[docs]
def process_single(self, sample, rank=None):
# check if it's generated already
if self.support_text_key in sample[Fields.meta]:
return sample
client = get_model(self.model_key, rank=rank)
if self.summary_key not in sample[Fields.meta]:
logger.warning(f"{self.summary_key} does not exist in the meta field!")
return sample
summary = sample[Fields.meta][self.summary_key]
if not isinstance(summary, str):
logger.warning("Invalid input summary!")
return sample
input_prompt = self.input_template.format(text=sample[self.text_key], summary=summary)
messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": input_prompt}]
support_text = ""
for i in range(self.try_num):
try:
response = client(messages, **self.sampling_params)
support_text = response.strip()
if len(support_text) > 0:
break
except Exception as e:
logger.warning(f"Exception: {e}")
# default to summary if return None
if not support_text:
support_text = summary
sample[Fields.meta][self.support_text_key] = support_text
return sample