# Source code for data_juicer.ops.mapper.extract_event_mapper
import re
from itertools import chain
from typing import Dict, Optional
from loguru import logger
from pydantic import PositiveInt
from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.model_utils import get_model, prepare_model
from ..common import split_text_by_punctuation
# Registry key of this operator: ExtractEventMapper is registered in both
# the OPERATORS and UNFORKABLE registries under this name.
OP_NAME: str = 'extract_event_mapper'
# TODO: LLM-based inference.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class ExtractEventMapper(Mapper):
    """
    Extract events and relevant characters in the text.

    For every input text an API model is prompted to summarize the plot as a
    numbered list of events, each with its related characters. The reply is
    parsed with ``output_pattern`` and the batch is expanded so that every
    extracted event becomes its own output row: all other fields of the
    source sample are repeated once per event, and the event description and
    character list go to ``event_desc_key`` / ``relevant_char_key``. Samples
    for which no event could be extracted produce no output rows.
    """

    _batched_op = True

    # Default (Chinese) system prompt. It asks the model to summarize the
    # plot of the given text point by point, stay faithful to the source,
    # keep proper nouns, and list the main characters of each plot point,
    # using the "### 情节N" / "**情节描述**" / "**相关人物**" layout that
    # DEFAULT_OUTPUT_PATTERN parses.
    DEFAULT_SYSTEM_PROMPT = ('给定一段文本,对文本的情节进行分点总结,并抽取与情节相关的人物。\n'
                             '要求:\n'
                             '- 尽量不要遗漏内容,不要添加文本中没有的情节,符合原文事实\n'
                             '- 联系上下文说明前因后果,但仍然需要符合事实\n'
                             '- 不要包含主观看法\n'
                             '- 注意要尽可能保留文本的专有名词\n'
                             '- 注意相关人物需要在对应情节中出现\n'
                             '- 只抽取情节中的主要人物,不要遗漏情节的主要人物\n'
                             '- 总结格式如下:\n'
                             '### 情节1:\n'
                             '- **情节描述**: ...\n'
                             '- **相关人物**:人物1,人物2,人物3,...\n'
                             '### 情节2:\n'
                             '- **情节描述**: ...\n'
                             '- **相关人物**:人物1,人物2,...\n'
                             '### 情节3:\n'
                             '- **情节描述**: ...\n'
                             '- **相关人物**:人物1,...\n'
                             '...\n')

    # User message template; ``{text}`` is replaced by the sample text,
    # wrapped in a fenced block under a "文本" ("text") heading.
    DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n'

    # Compiled with re.VERBOSE | re.DOTALL in parse_output. VERBOSE ignores
    # the layout whitespace and treats an unescaped '#' as a comment, which
    # is why the headings are written as '\#'. The three groups are: event
    # number, event description, and raw character list (which runs until
    # the next '###' heading or the end of the reply).
    DEFAULT_OUTPUT_PATTERN = r"""
\#\#\#\s*情节(\d+):\s*
-\s*\*\*情节描述\*\*\s*:\s*(.*?)\s*
-\s*\*\*相关人物\*\*\s*:\s*(.*?)(?=\#\#\#|\Z)
"""

    def __init__(self,
                 api_model: str = 'gpt-4o',
                 *,
                 event_desc_key: str = Fields.event_description,
                 relevant_char_key: str = Fields.relevant_characters,
                 api_endpoint: Optional[str] = None,
                 response_path: Optional[str] = None,
                 system_prompt: Optional[str] = None,
                 input_template: Optional[str] = None,
                 output_pattern: Optional[str] = None,
                 try_num: PositiveInt = 3,
                 drop_text: bool = False,
                 model_params: Optional[Dict] = None,
                 sampling_params: Optional[Dict] = None,
                 **kwargs):
        """
        Initialization method.

        :param api_model: API model name.
        :param event_desc_key: The field name to store the event descriptions.
            It's "__dj__event_description__" in default.
        :param relevant_char_key: The field name to store the relevant
            characters to the events. It's "__dj__relevant_characters__" in
            default.
        :param api_endpoint: URL endpoint for the API.
        :param response_path: Path to extract content from the API response.
            Defaults to 'choices.0.message.content'.
        :param system_prompt: System prompt for the task. Falls back to
            DEFAULT_SYSTEM_PROMPT when empty or None.
        :param input_template: Template for building the model input; must
            contain a ``{text}`` placeholder. Falls back to
            DEFAULT_INPUT_TEMPLATE when empty or None.
        :param output_pattern: Regular expression for parsing model output.
            It is compiled with re.VERBOSE | re.DOTALL and must define three
            groups (event number, description, characters). Falls back to
            DEFAULT_OUTPUT_PATTERN when empty or None.
        :param try_num: The number of retry attempts when there is an API
            call error or output parsing error.
        :param drop_text: If drop the text in the output.
        :param model_params: Parameters for initializing the API model.
            None (the default) means no extra parameters.
        :param sampling_params: Extra parameters passed to the API call.
            e.g {'temperature': 0.9, 'top_p': 0.95}. None (the default)
            means no extra parameters.
        :param kwargs: Extra keyword arguments.
        """
        super().__init__(**kwargs)

        self.event_desc_key = event_desc_key
        self.relevant_char_key = relevant_char_key

        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
        self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
        self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN

        # None rather than a `{}` default: a dict literal default is created
        # once and shared by every call, so keeping it on `self` would alias
        # one dict across all instances of this operator.
        self.sampling_params = sampling_params or {}
        self.model_key = prepare_model(model_type='api',
                                       model=api_model,
                                       endpoint=api_endpoint,
                                       response_path=response_path,
                                       **(model_params or {}))

        self.try_num = try_num
        self.drop_text = drop_text

    def parse_output(self, raw_output):
        """
        Parse a raw model reply into event descriptions and character lists.

        :param raw_output: Text returned by the API model.
        :return: A tuple ``(event_list, character_list)`` of equal length.
            ``event_list[i]`` is the description of the i-th parsed event,
            and ``character_list[i]`` is ``split_text_by_punctuation`` applied
            to its character field. Events whose character field splits into
            nothing are skipped, together with their description.
        """
        pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)

        event_list, character_list = [], []
        # Each match is (event number, description, characters); the event
        # number is not kept.
        for _, desc, chars in pattern.findall(raw_output):
            chars = split_text_by_punctuation(chars)
            if len(chars) > 0:
                event_list.append(desc)
                character_list.append(chars)

        return event_list, character_list

    def _process_single_sample(self, text='', rank=None):
        """
        Query the API model for one text, retrying up to ``try_num`` times.

        An attempt is repeated when the API call or parsing raises, or when
        parsing yields no events. Exceptions are logged as warnings and
        swallowed, so a failing sample yields empty results instead of
        aborting the batch.

        :param text: The input text to extract events from.
        :param rank: Forwarded to ``get_model`` when fetching the client.
        :return: ``(event_list, character_list)`` from the first attempt that
            yields at least one event, otherwise two empty lists.
        """
        client = get_model(self.model_key, rank=rank)

        input_prompt = self.input_template.format(text=text)
        messages = [{
            'role': 'system',
            'content': self.system_prompt
        }, {
            'role': 'user',
            'content': input_prompt
        }]

        event_list, character_list = [], []
        for _ in range(self.try_num):
            try:
                output = client(messages, **self.sampling_params)
                event_list, character_list = self.parse_output(output)
                if len(event_list) > 0:
                    break
            except Exception as e:
                logger.warning(f'Exception: {e}')

        return event_list, character_list

    def process_batched(self, samples, rank=None):
        """
        Extract events for a batch and expand it to one row per event.

        :param samples: Batched samples, a dict-like mapping from field name
            to a list of per-sample values. ``samples[self.text_key]`` holds
            the input texts.
        :param rank: Forwarded to ``get_model`` when fetching the client.
        :return: The same ``samples`` object, modified in place:
            - Every field list is rewritten so that sample ``i`` is repeated
              once per event extracted from it.
            - ``event_desc_key`` and ``relevant_char_key`` are added, holding
              one event description and one character list per row.
            - The text field is removed when ``drop_text`` is set.
            - Samples with no extracted event disappear from the output.
        """
        sample_num = len(samples[self.text_key])

        events, characters = [], []
        for text in samples[self.text_key]:
            cur_events, cur_characters = self._process_single_sample(
                text, rank=rank)
            events.append(cur_events)
            characters.append(cur_characters)

        if self.drop_text:
            samples.pop(self.text_key)

        # Wrap each remaining value of sample i in a list repeated once per
        # event of that sample. A sample with zero events gets an empty list,
        # which drops it after flattening.
        for key in samples:
            samples[key] = [[samples[key][i]] * len(events[i])
                            for i in range(sample_num)]
        samples[self.event_desc_key] = events
        samples[self.relevant_char_key] = characters

        # Flatten the per-sample nested lists into one row per event.
        for key in samples:
            samples[key] = list(chain.from_iterable(samples[key]))

        return samples