Source code for data_juicer.ops.mapper.relation_identity_mapper

import re
from typing import Dict, Optional

from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = "relation_identity_mapper"


# TODO: LLM-based inference.
[docs] @TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class RelationIdentityMapper(Mapper): """ identify relation between two entity in the text. """ DEFAULT_SYSTEM_PROMPT_TEMPLATE = ( "给定关于{entity1}{entity2}的文本信息。" "判断{entity1}{entity2}之间的关系。\n" "要求:\n" "- 关系用一个或多个词语表示,必要时可以加一个形容词来描述这段关系\n" "- 输出关系时不要参杂任何标点符号\n" "- 需要你进行合理的推理才能得出结论\n" "- 如果两个人物身份是同一个人,输出关系为:另一个身份\n" "- 输出格式为:\n" "分析推理:...\n" "所以{entity2}{entity1}的:...\n" "- 注意输出的是{entity2}{entity1}的什么关系,而不是{entity1}{entity2}的什么关系" ) DEFAULT_INPUT_TEMPLATE = "关于{entity1}{entity2}的文本信息:\n```\n{text}\n```\n" DEFAULT_OUTPUT_PATTERN_TEMPLATE = r""" \s*分析推理:\s*(.*?)\s* \s*所以{entity2}{entity1}的:\s*(.*?)\Z """
[docs] def __init__( self, api_model: str = "gpt-4o", source_entity: str = None, target_entity: str = None, *, output_key: str = MetaKeys.role_relation, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt_template: Optional[str] = None, input_template: Optional[str] = None, output_pattern_template: Optional[str] = None, try_num: PositiveInt = 3, drop_text: bool = False, model_params: Dict = {}, sampling_params: Dict = {}, **kwargs, ): """ Initialization method. :param api_model: API model name. :param source_entity: The source entity of the relation to be identified. :param target_entity: The target entity of the relation to be identified. :param output_key: The output key in the meta field in the samples. It is 'role_relation' in default. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. :param system_prompt_template: System prompt template for the task. :param input_template: Template for building the model input. :param output_pattern_template: Regular expression template for parsing model output. :param try_num: The number of retry attempts when there is an API call error or output parsing error. :param drop_text: If drop the text in the output. :param model_params: Parameters for initializing the API model. :param sampling_params: Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} :param kwargs: Extra keyword arguments. """ super().__init__(**kwargs) if source_entity is None or target_entity is None: logger.warning("source_entity and target_entity cannot be None") self.source_entity = source_entity self.target_entity = target_entity self.output_key = output_key system_prompt_template = system_prompt_template or self.DEFAULT_SYSTEM_PROMPT_TEMPLATE self.system_prompt = system_prompt_template.format(entity1=source_entity, entity2=target_entity) self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE output_pattern_template = output_pattern_template or self.DEFAULT_OUTPUT_PATTERN_TEMPLATE self.output_pattern = output_pattern_template.format(entity1=source_entity, entity2=target_entity) self.sampling_params = sampling_params self.model_key = prepare_model( model_type="api", model=api_model, endpoint=api_endpoint, response_path=response_path, **model_params ) self.try_num = try_num self.drop_text = drop_text
[docs] def parse_output(self, raw_output): pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) matches = pattern.findall(raw_output) relation = "" for match in matches: _, relation = match relation = relation.strip() return relation
[docs] def process_single(self, sample, rank=None): meta = sample[Fields.meta] if self.output_key in meta: return sample client = get_model(self.model_key, rank=rank) text = sample[self.text_key] input_prompt = self.input_template.format(entity1=self.source_entity, entity2=self.target_entity, text=text) messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": input_prompt}] relation = "" for i in range(self.try_num): try: output = client(messages, **self.sampling_params) relation = self.parse_output(output) if len(relation) > 0: break except Exception as e: logger.warning(f"Exception: {e}") meta[self.output_key] = relation if self.drop_text: sample.pop(self.text_key) return sample