import re
from typing import Dict, Optional
from loguru import logger
from pydantic import PositiveInt
from data_juicer.ops.base_op import OPERATORS, Aggregator
from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
is_string_list, nested_access,
nested_set)
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model
from .nested_aggregator import NestedAggregator

torch = LazyLoader('torch', 'torch')
vllm = LazyLoader('vllm', 'vllm')

OP_NAME = 'entity_attribute_aggregator'

# TODO: LLM-based inference.
@OPERATORS.register_module(OP_NAME)
class EntityAttributeAggregator(Aggregator):
"""
Return conclusion of the given entity's attribute from some docs.
"""
DEFAULT_SYSTEM_TEMPLATE = (
'给定与`{entity}`相关的一些文档,总结`{entity}`的`{attribute}`。\n'
'要求:\n'
'- 尽量使用原文专有名词\n'
'- 联系上下文,自动忽略上下文不一致的细节错误\n'
'- 只对文档中与`{entity}`的`{attribute}`有关的内容进行总结\n'
'- 字数限制在**{word_limit}字以内**\n'
'- 要求输出格式如下:\n'
'# {entity}\n'
'## {attribute}\n'
'...\n'
'{example}')
DEFAULT_EXAMPLE_PROMPT = ('- 例如,根据相关文档总结`孙悟空`的`出身背景`,**100字**以内的样例如下:\n'
'`孙悟空`的`出身背景`总结:\n'
'# 孙悟空\n'
'## 出身背景\n'
'号称齐天大圣,花果山水帘洞的美猴王、西行取经队伍中的大师兄。'
'师父是唐僧玄奘,曾拜菩提祖师学艺。'
'亲生父母未知,自石头中孕育而生。自认斗战胜佛,最怕观世音菩萨和紧箍咒。\n')
DEFAULT_INPUT_TEMPLATE = ('`{entity}`的相关文档:\n'
'{sub_docs}\n\n'
'`{entity}`的`{attribute}`总结:\n')
DEFAULT_OUTPUT_PATTERN_TEMPLATE = r'\#\s*{entity}\s*\#\#\s*{attribute}\s*(.*?)\Z' # noqa: E501
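    # Note: the default prompts above are written in Chinese. Roughly, the
    # system template asks the model to summarize the `{attribute}` of
    # `{entity}` from the given documents, preferring the wording of the
    # originals, ignoring minor details that are inconsistent with the
    # context, summarizing only content related to the entity's attribute,
    # staying within `{word_limit}` characters, and answering in the
    # "# {entity}\n## {attribute}\n..." layout illustrated by the example
    # prompt (a sample summary of Sun Wukong's background in at most 100
    # characters). The output pattern then extracts whatever follows those
    # two headings.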
def __init__(self,
api_model: str = 'gpt-4o',
                 entity: Optional[str] = None,
                 attribute: Optional[str] = None,
                 input_key: Optional[str] = None,
                 output_key: Optional[str] = None,
word_limit: PositiveInt = 100,
max_token_num: Optional[PositiveInt] = None,
*,
api_endpoint: Optional[str] = None,
response_path: Optional[str] = None,
system_prompt_template: Optional[str] = None,
example_prompt: Optional[str] = None,
input_template: Optional[str] = None,
output_pattern_template: Optional[str] = None,
try_num: PositiveInt = 3,
model_params: Dict = {},
sampling_params: Dict = {},
**kwargs):
"""
Initialization method.
:param api_model: API model name.
:param entity: The given entity.
:param attribute: The given attribute.
        :param input_key: The input field key in the samples. Supports
            nested keys such as "__dj__stats__.text_len". Defaults to the
            op's text_key.
        :param output_key: The output field key in the samples. Supports
            nested keys such as "__dj__stats__.text_len". Defaults to the
            input_key.
        :param word_limit: The expected output length, used in the prompt
            to constrain the summary.
        :param max_token_num: The maximum total number of tokens of the sub
            documents sent in a single request. No limit if it is None.
:param api_endpoint: URL endpoint for the API.
:param response_path: Path to extract content from the API response.
Defaults to 'choices.0.message.content'.
:param system_prompt_template: The system prompt template.
:param example_prompt: The example part in the system prompt.
:param input_template: The input template.
        :param output_pattern_template: The template of the regular
            expression pattern used to parse the model output.
:param try_num: The number of retry attempts when there is an API
call error or output parsing error.
:param model_params: Parameters for initializing the API model.
:param sampling_params: Extra parameters passed to the API call.
            e.g. {'temperature': 0.9, 'top_p': 0.95}
:param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)
if entity is None or attribute is None:
raise ValueError('The entity and attribute cannot be None!')
self.entity = entity
self.attribute = attribute
self.input_key = input_key or self.text_key
self.output_key = output_key or self.input_key
self.word_limit = word_limit
self.max_token_num = max_token_num
system_prompt_template = system_prompt_template or \
self.DEFAULT_SYSTEM_TEMPLATE
self.example_prompt = example_prompt or self.DEFAULT_EXAMPLE_PROMPT
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
output_pattern_template = output_pattern_template or \
self.DEFAULT_OUTPUT_PATTERN_TEMPLATE
self.system_prompt = system_prompt_template.format(
entity=self.entity,
attribute=self.attribute,
word_limit=self.word_limit,
example=self.example_prompt)
self.output_pattern = output_pattern_template.format(
entity=entity, attribute=attribute)
self.sampling_params = sampling_params
self.model_key = prepare_model(model_type='api',
model=api_model,
endpoint=api_endpoint,
response_path=response_path,
return_processor=True,
**model_params)
self.try_num = try_num
self.nested_sum = NestedAggregator(model=api_model,
max_token_num=max_token_num,
api_endpoint=api_endpoint,
response_path=response_path,
try_num=try_num,
model_params=model_params,
sampling_params=sampling_params)
def parse_output(self, response):
pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
matches = pattern.findall(response)
if matches:
result = matches[0].strip()
else:
result = ''
return result
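    # For example (illustrative), with entity `孙悟空` and attribute
    # `出身背景`, parse_output turns a response such as
    # "# 孙悟空\n## 出身背景\n自石头中孕育而生。" into "自石头中孕育而生。";
    # a response that does not follow the expected "# entity / ## attribute"
    # layout yields an empty string.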
def attribute_summary(self, sub_docs, rank=None):
if not sub_docs:
return ''
model, tokenizer = get_model(self.model_key, rank, self.use_cuda())
        # Estimate token counts and split the sub documents into groups whose
        # total token count stays under max_token_num (when it is set).
        token_nums = [len(tokenizer.encode(sub_doc)) for sub_doc in sub_docs]
        group_docs = avg_split_string_list_under_limit(sub_docs, token_nums,
                                                       self.max_token_num)
results = []
for docs in group_docs:
doc_str = '\n\n'.join(docs)
input_prompt = self.input_template.format(entity=self.entity,
attribute=self.attribute,
sub_docs=doc_str)
messages = [{
'role': 'system',
'content': self.system_prompt
}, {
'role': 'user',
'content': input_prompt
}]
            result = ''
            # Retry the API call up to try_num times and keep the first
            # successfully parsed, non-empty result.
            for _ in range(self.try_num):
                try:
                    response = model(messages, **self.sampling_params)
                    result = self.parse_output(response)
                    if len(result) > 0:
                        break
                except Exception as e:
                    logger.warning(f'Exception: {e}')
results.append(result)
        # Merge the per-group summaries into one final summary.
        return self.nested_sum.recursive_summary(results)
    def process_single(self, sample=None, rank=None):
        # The input field is expected to hold a list of document strings;
        # skip the sample otherwise.
        sub_docs = nested_access(sample, self.input_key)
        if not is_string_list(sub_docs):
            return sample
sample = nested_set(sample, self.output_key,
self.attribute_summary(sub_docs, rank=rank))
return sample
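

# Minimal usage sketch (illustrative only, not part of the operator): it
# assumes an OpenAI-compatible endpoint is configured for the API model and
# that the sample's `text` field holds a list of document strings about the
# entity. The field names and documents below are hypothetical.
if __name__ == '__main__':
    op = EntityAttributeAggregator(
        api_model='gpt-4o',
        entity='孙悟空',
        attribute='出身背景',
        input_key='text',
        output_key='entity_attribute',  # hypothetical output field
        word_limit=100)
    demo_sample = {'text': ['相关文档一……', '相关文档二……']}  # hypothetical docs
    demo_sample = op.process_single(demo_sample)
    print(demo_sample['entity_attribute'])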