data_juicer.ops.mapper.extract_entity_attribute_mapper 源代码
import re
from typing import Dict, List, Optional
import numpy as np
from loguru import logger
from pydantic import PositiveInt
from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model
OP_NAME = "extract_entity_attribute_mapper"
# TODO: LLM-based inference.
[文档]
@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class ExtractEntityAttributeMapper(Mapper):
"""Extracts attributes for given entities from the text and stores them in the sample's
metadata.
This operator uses an API model to extract specified attributes for given entities from
the input text. It constructs prompts based on provided templates and parses the model's
output to extract attribute descriptions and supporting text. The extracted data is
stored in the sample's metadata under the specified keys. If the required metadata
fields already exist, the operator skips processing for that sample. The operator
retries the API call and parsing up to a specified number of times in case of errors.
The default system prompt, input template, and parsing patterns are used if not
provided."""
DEFAULT_SYSTEM_PROMPT_TEMPLATE = (
"给定一段文本,从文本中总结{entity}的{attribute},并且从原文摘录最能说明该{attribute}的代表性示例。\n"
"要求:\n"
"- 摘录的示例应该简短。\n"
"- 遵循如下的回复格式:\n"
"# {entity}\n"
"## {attribute}:\n"
"...\n"
"### 代表性示例摘录1:\n"
"```\n"
"...\n"
"```\n"
"### 代表性示例摘录2:\n"
"```\n"
"...\n"
"```\n"
"...\n"
)
DEFAULT_INPUT_TEMPLATE = "# 文本\n```\n{text}\n```\n"
DEFAULT_ATTR_PATTERN_TEMPLATE = r"\#\#\s*{attribute}:\s*(.*?)(?=\#\#\#|\Z)"
DEFAULT_DEMON_PATTERN = r"\#\#\#\s*代表性示例摘录(\d+):\s*```\s*(.*?)```\s*(?=\#\#\#|\Z)" # noqa: E501
[文档]
def __init__(
self,
api_model: str = "gpt-4o",
query_entities: List[str] = [],
query_attributes: List[str] = [],
*,
entity_key: str = MetaKeys.main_entities,
attribute_key: str = MetaKeys.attributes,
attribute_desc_key: str = MetaKeys.attribute_descriptions,
support_text_key: str = MetaKeys.attribute_support_texts,
api_endpoint: Optional[str] = None,
response_path: Optional[str] = None,
system_prompt_template: Optional[str] = None,
input_template: Optional[str] = None,
attr_pattern_template: Optional[str] = None,
demo_pattern: Optional[str] = None,
try_num: PositiveInt = 3,
drop_text: bool = False,
model_params: Dict = {},
sampling_params: Dict = {},
**kwargs,
):
"""
Initialization method.
:param api_model: API model name.
:param query_entities: Entity list to be queried.
:param query_attributes: Attribute list to be queried.
:param entity_key: The key name in the meta field to store the
given main entity for attribute extraction. It's "entity" in
default.
:param attribute_key: The key name in the meta field to
store the given attribute to be extracted. It's "attribute"
in default.
:param attribute_desc_key: The key name in the meta field to store
the extracted attribute description. It's
"attribute_description" in default.
:param support_text_key: The key name in the meta field to store
the attribute support text extracted from the raw text.
It's "support_text" in default.
:param api_endpoint: URL endpoint for the API.
:param response_path: Path to extract content from the API response.
Defaults to 'choices.0.message.content'.
:param system_prompt_template: System prompt template for the
task. Need to be specified by given entity and attribute.
:param input_template: Template for building the model input.
:param attr_pattern_template: Pattern for parsing the attribute from
output. Need to be specified by given attribute.
:param demo_pattern: Pattern for parsing the demonstration from
output to support the attribute.
:param try_num: The number of retry attempts when there is an API
call error or output parsing error.
:param drop_text: If drop the text in the output.
:param model_params: Parameters for initializing the API model.
:param sampling_params: Extra parameters passed to the API call.
e.g {'temperature': 0.9, 'top_p': 0.95}
:param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)
self.query_entities = query_entities
self.query_attributes = query_attributes
self.entity_key = entity_key
self.attribute_key = attribute_key
self.attribute_desc_key = attribute_desc_key
self.support_text_key = support_text_key
self.system_prompt_template = system_prompt_template or self.DEFAULT_SYSTEM_PROMPT_TEMPLATE
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
self.attr_pattern_template = attr_pattern_template or self.DEFAULT_ATTR_PATTERN_TEMPLATE
self.demo_pattern = demo_pattern or self.DEFAULT_DEMON_PATTERN
self.sampling_params = sampling_params
self.model_key = prepare_model(
model_type="api", model=api_model, endpoint=api_endpoint, response_path=response_path, **model_params
)
self.try_num = try_num
self.drop_text = drop_text
[文档]
def parse_output(self, raw_output, attribute_name):
attribute_pattern = self.attr_pattern_template.format(attribute=attribute_name)
pattern = re.compile(attribute_pattern, re.VERBOSE | re.DOTALL)
matches = pattern.findall(raw_output)
if matches:
attribute = matches[0].strip()
else:
attribute = ""
pattern = re.compile(self.demo_pattern, re.VERBOSE | re.DOTALL)
matches = pattern.findall(raw_output)
demos = [demo.strip() for _, demo in matches if demo.strip()]
return attribute, demos
def _process_single_text(self, text="", rank=None):
client = get_model(self.model_key, rank=rank)
entities, attributes, descs, demo_lists = [], [], [], []
for entity in self.query_entities:
for attribute in self.query_attributes:
system_prompt = self.system_prompt_template.format(entity=entity, attribute=attribute)
input_prompt = self.input_template.format(text=text)
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": input_prompt}]
desc, demos = "", np.array([], dtype=str)
for _ in range(self.try_num):
try:
output = client(messages, **self.sampling_params)
cur_desc, cur_demos = self.parse_output(output, attribute)
if cur_desc and len(cur_demos) > 0:
desc = cur_desc
demos = cur_demos
break
except Exception as e:
logger.warning(f"Exception: {e}")
entities.append(entity)
attributes.append(attribute)
descs.append(desc)
demo_lists.append(demos)
return entities, attributes, descs, demo_lists
[文档]
def process_single(self, sample, rank=None):
# check if it's generated already
if set([self.entity_key, self.attribute_key, self.attribute_desc_key, self.support_text_key]) <= set(
sample[Fields.meta].keys()
):
return sample
res = self._process_single_text(sample[self.text_key], rank=rank)
entities, attributes, descs, demo_lists = res
if self.drop_text:
sample.pop(self.text_key)
sample[Fields.meta][self.entity_key] = entities
sample[Fields.meta][self.attribute_key] = attributes
sample[Fields.meta][self.attribute_desc_key] = descs
sample[Fields.meta][self.support_text_key] = demo_lists
return sample