data_juicer.ops.aggregator.most_relevant_entities_aggregator module

class data_juicer.ops.aggregator.most_relevant_entities_aggregator.MostRelevantEntitiesAggregator(api_model: str = 'gpt-4o', entity: str = None, query_entity_type: str = None, input_key: str = 'event_description', output_key: str = 'most_relevant_entities', max_token_num: Annotated[int, Gt(gt=0)] | None = None, *, api_endpoint: str | None = None, response_path: str | None = None, system_prompt_template: str | None = None, input_template: str | None = None, output_pattern: str | None = None, try_num: Annotated[int, Gt(gt=0)] = 3, model_params: Dict = {}, sampling_params: Dict = {}, **kwargs)[source]

Bases: Aggregator

Extracts and ranks entities closely related to a given entity from provided texts.

The operator calls a language model API to identify and rank entities, filtering out entities of the same type as the given entity. The returned list is sorted in descending order of importance. Input texts are aggregated and passed to the model, optionally capped by a token limit (max_token_num). The model output is parsed with a regular expression to extract the relevant entities. Results are stored in the batch metadata under the key given by output_key ('most_relevant_entities' by default). If the API call or the output parsing fails, the operator retries up to try_num times. The system prompt, input template, and output pattern can all be customized.
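The prompt-and-parse flow described above can be sketched without any API call. The snippet below is a minimal illustration, not the library's implementation: it copies the default input template and output pattern from the class attributes on this page, while the helper names `build_prompt` and `parse_entities`, the `re.DOTALL` flag, and the comma-splitting rule are assumptions made for the example.

```python
import re

# Copied from DEFAULT_INPUT_TEMPLATE and DEFAULT_OUTPUT_PATTERN on this page.
INPUT_TEMPLATE = (
    "`{entity}`的相关文档:\n{sub_docs}\n\n"
    "与`{entity}`最相关的一些`{entity_type}`:\n"
)
OUTPUT_PATTERN = r"\#\#\s*列表\s*(.*?)\Z"


def build_prompt(entity, entity_type, sub_docs):
    """Hypothetical helper: join the sub-documents and fill the template."""
    return INPUT_TEMPLATE.format(
        entity=entity, entity_type=entity_type, sub_docs="\n".join(sub_docs)
    )


def parse_entities(response):
    """Hypothetical helper: extract the ranked list after the '## 列表' heading."""
    # Assumption: DOTALL lets the captured group span several lines.
    match = re.search(OUTPUT_PATTERN, response, re.DOTALL)
    if match is None:
        return []
    # Assumption: entities are separated by ASCII or full-width commas.
    parts = re.split(r"[,,、]", match.group(1))
    return [p.strip() for p in parts if p.strip()]


# A made-up model response in the format requested by the default system prompt.
response = "## 分析\n孙悟空是唐僧的大徒弟。\n## 列表\n孙悟空, 猪八戒, 沙僧"
entities = parse_entities(response)
prompt = build_prompt("唐僧", "人物", ["doc a", "doc b"])
```

Because the default pattern anchors on the `## 列表` heading, a custom system_prompt_template that changes the response format would likely need a matching output_pattern.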

DEFAULT_SYSTEM_TEMPLATE = '给定与`{entity}`相关的一些文档,总结一些与`{entity}`最为相关的`{entity_type}`。\n要求:\n- 不用包含与{entity}为同一{entity_type}的{entity_type}。\n- 请按照人物的重要性进行排序,**越重要人物在列表越前面**。\n- 你的返回格式如下:\n## 分析\n你对各个{entity_type}与{entity}关联度的分析\n## 列表\n人物1, 人物2, 人物3, ...'
DEFAULT_INPUT_TEMPLATE = '`{entity}`的相关文档:\n{sub_docs}\n\n与`{entity}`最相关的一些`{entity_type}`:\n'
DEFAULT_OUTPUT_PATTERN = '\\#\\#\\s*列表\\s*(.*?)\\Z'
__init__(api_model: str = 'gpt-4o', entity: str = None, query_entity_type: str = None, input_key: str = 'event_description', output_key: str = 'most_relevant_entities', max_token_num: Annotated[int, Gt(gt=0)] | None = None, *, api_endpoint: str | None = None, response_path: str | None = None, system_prompt_template: str | None = None, input_template: str | None = None, output_pattern: str | None = None, try_num: Annotated[int, Gt(gt=0)] = 3, model_params: Dict = {}, sampling_params: Dict = {}, **kwargs)[source]

Initialization method.

Parameters:
  • api_model -- API model name.

  • entity -- The given entity.

  • query_entity_type -- The type of queried relevant entities.

  • input_key -- The input key in the meta field of the samples. It is "event_description" by default.

  • output_key -- The output key in the aggregation field of the samples. It is "most_relevant_entities" by default.

  • max_token_num -- The maximum total number of tokens across the sub-documents. No limit is applied if it is None.

  • api_endpoint -- URL endpoint for the API.

  • response_path -- Path to extract content from the API response. Defaults to 'choices.0.message.content'.

  • system_prompt_template -- The system prompt template.

  • input_template -- The input template.

  • output_pattern -- The output pattern.

  • try_num -- The number of retry attempts when there is an API call error or output parsing error.

  • model_params -- Parameters for initializing the API model.

  • sampling_params -- Extra parameters passed to the API call, e.g. {'temperature': 0.9, 'top_p': 0.95}.

  • kwargs -- Extra keyword arguments.
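As a rough guide to how these parameters fit together, the following sketch collects a set of keyword arguments that mirror the documented signature. The entity, entity type, and token limit are illustrative values only, and the commented-out construction assumes that data_juicer is installed and that API credentials are configured in the environment.

```python
# Keyword arguments mirroring the __init__ signature documented above.
# The entity, entity type, and limits are illustrative values only.
aggregator_kwargs = {
    "api_model": "gpt-4o",
    "entity": "孙悟空",
    "query_entity_type": "人物",
    "input_key": "event_description",
    "output_key": "most_relevant_entities",
    "max_token_num": 4096,  # must be > 0, or None for no limit
    "try_num": 3,           # must be > 0
    "sampling_params": {"temperature": 0.9, "top_p": 0.95},
}

# With data_juicer installed and API credentials configured (assumed),
# the operator would be constructed along these lines:
#
#   from data_juicer.ops.aggregator.most_relevant_entities_aggregator import (
#       MostRelevantEntitiesAggregator,
#   )
#   op = MostRelevantEntitiesAggregator(**aggregator_kwargs)
```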

parse_output(response)[source]
query_most_relevant_entities(sub_docs, rank=None)[source]
process_single(sample=None, rank=None)[source]

Sample-level processing that maps a batched sample to a single sample. The input must be the output of a Grouper OP.

Parameters:

sample -- batched sample to aggregate

Returns:

aggregated sample
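
The batched-sample-to-sample step can be pictured with a small stand-in. This is a sketch under stated assumptions rather than the operator's actual code: the field names `meta` and `batch_meta`, the function `aggregate`, the whitespace-based token count, and the stubbed `flaky_model` are all hypothetical and chosen only to show the token budget, the retry loop, and where the result is written.

```python
def aggregate(sample, query_fn, input_key="event_description",
              output_key="most_relevant_entities",
              max_token_num=None, try_num=3):
    """Illustrative stand-in for process_single; not the library's code."""
    # Collect one sub-document per grouped sample (hypothetical "meta" field).
    sub_docs = [m[input_key] for m in sample["meta"] if input_key in m]

    # Keep sub-documents until the token budget would be exceeded.
    # Assumption: whitespace-separated words stand in for real tokens.
    if max_token_num is not None:
        kept, used = [], 0
        for doc in sub_docs:
            cost = len(doc.split())
            if used + cost > max_token_num:
                break
            kept.append(doc)
            used += cost
        sub_docs = kept

    # Retry the (stubbed) model call up to try_num times.
    result = []
    for _ in range(try_num):
        try:
            result = query_fn(sub_docs)
            break
        except Exception:
            continue

    # Write the result into a hypothetical batch-level metadata dict.
    out = dict(sample)
    out["batch_meta"] = {**sample.get("batch_meta", {}), output_key: result}
    return out


calls = {"n": 0, "docs": None}


def flaky_model(sub_docs):
    # Fails once, then returns a fixed ranking, to exercise the retry loop.
    calls["n"] += 1
    calls["docs"] = list(sub_docs)
    if calls["n"] == 1:
        raise RuntimeError("transient API error")
    return ["entity_a", "entity_b"]


batched = {"meta": [{"event_description": "first event text"},
                    {"event_description": "second event with more words"}]}
aggregated = aggregate(batched, flaky_model, max_token_num=5)
```

In this toy run only the first sub-document fits the five-word budget, the first call fails, and the second call supplies the list stored under output_key.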