Source code for data_juicer.ops.mapper.relation_identity_mapper

import re
from typing import Dict, Optional

from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.common_utils import nested_access, nested_set
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'relation_identity_mapper'


# TODO: LLM-based inference.
[docs] @OPERATORS.register_module(OP_NAME) class RelationIdentityMapper(Mapper): """ identify relation between two entity in the text. """ DEFAULT_SYSTEM_PROMPT_TEMPLATE = ( '给定关于{entity1}{entity2}的文本信息。' '判断{entity1}{entity2}之间的关系。\n' '要求:\n' '- 关系用一个或多个词语表示,必要时可以加一个形容词来描述这段关系\n' '- 输出关系时不要参杂任何标点符号\n' '- 需要你进行合理的推理才能得出结论\n' '- 如果两个人物身份是同一个人,输出关系为:另一个身份\n' '- 输出格式为:\n' '分析推理:...\n' '所以{entity2}{entity1}的:...\n' '- 注意输出的是{entity2}{entity1}的什么关系,而不是{entity1}{entity2}的什么关系') DEFAULT_INPUT_TEMPLATE = '关于{entity1}{entity2}的文本信息:\n```\n{text}\n```\n' DEFAULT_OUTPUT_PATTERN_TEMPLATE = r""" \s*分析推理:\s*(.*?)\s* \s*所以{entity2}{entity1}的:\s*(.*?)\Z """
[docs] def __init__(self, api_model: str = 'gpt-4o', source_entity: str = None, target_entity: str = None, input_key: str = None, output_key: str = None, *, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt_template: Optional[str] = None, input_template: Optional[str] = None, output_pattern_template: Optional[str] = None, try_num: PositiveInt = 3, drop_text: bool = False, model_params: Dict = {}, sampling_params: Dict = {}, **kwargs): """ Initialization method. :param api_model: API model name. :param source_entity: The source entity of the relation to be identified. :param target_entity: The target entity of the relation to be identified. :param input_key: The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. :param output_key: The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is input_key in default. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. :param system_prompt_template: System prompt template for the task. :param input_template: Template for building the model input. :param output_pattern_template: Regular expression template for parsing model output. :param try_num: The number of retry attempts when there is an API call error or output parsing error. :param drop_text: If drop the text in the output. :param model_params: Parameters for initializing the API model. :param sampling_params: Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} :param kwargs: Extra keyword arguments. """ super().__init__(**kwargs) if source_entity is None or target_entity is None: logger.warning('source_entity and target_entity cannot be None') self.source_entity = source_entity self.target_entity = target_entity self.input_key = input_key or self.text_key self.output_key = output_key or self.input_key system_prompt_template = system_prompt_template or \ self.DEFAULT_SYSTEM_PROMPT_TEMPLATE self.system_prompt = system_prompt_template.format( entity1=source_entity, entity2=target_entity) self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE output_pattern_template = output_pattern_template or \ self.DEFAULT_OUTPUT_PATTERN_TEMPLATE self.output_pattern = output_pattern_template.format( entity1=source_entity, entity2=target_entity) self.sampling_params = sampling_params self.model_key = prepare_model(model_type='api', model=api_model, endpoint=api_endpoint, response_path=response_path, **model_params) self.try_num = try_num self.drop_text = drop_text
[docs] def parse_output(self, raw_output): pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) matches = pattern.findall(raw_output) relation = '' for match in matches: _, relation = match relation = relation.strip() return relation
[docs] def process_single(self, sample, rank=None): client = get_model(self.model_key, rank=rank) text = nested_access(sample, self.input_key) input_prompt = self.input_template.format(entity1=self.source_entity, entity2=self.target_entity, text=text) messages = [{ 'role': 'system', 'content': self.system_prompt }, { 'role': 'user', 'content': input_prompt }] relation = '' for i in range(self.try_num): try: output = client(messages, **self.sampling_params) relation = self.parse_output(output) if len(relation) > 0: break except Exception as e: logger.warning(f'Exception: {e}') sample = nested_set(sample, self.output_key, relation) if self.drop_text: sample.pop(self.text_key) return sample