import re
from typing import Dict, List, Optional
from loguru import logger
from pydantic import NonNegativeInt, PositiveInt
from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model
OP_NAME = 'dialog_intent_detection_mapper'
# TODO: LLM-based inference.
@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class DialogIntentDetectionMapper(Mapper):
"""
    Mapper to generate the user's intent labels in a dialog. Input is read
    from history_key, query_key and response_key. The output is a list of
    intent labels and a list of analyses, one for each query in the dialog.
"""
DEFAULT_SYSTEM_PROMPT = (
'请判断用户和LLM多轮对话中用户的意图。\n'
'要求:\n'
'- 需要先进行分析,然后列出用户所具有的意图,下面是一个样例,请模仿样例格式输出'
'。\n'
'用户:你好,我最近对人工智能很感兴趣,能给我讲讲什么是机器学习吗?\n'
'意图分析:用户在请求信息,希望了解有关机器学习的基础知识。\n'
'意图类别:信息查找\n'
'LLM:你好!当然可以。机器学习是一种人工智能方法,允许计算机通过数据自动改进和学习。\n'
'用户:听起来很有趣,有没有推荐的入门书籍或资料?\n'
'意图分析:用户在请求建议,希望获取关于机器学习的入门资源。\n'
'意图类别:请求建议\n'
'LLM:有很多不错的入门书籍和资源。一本常被推荐的书是《Python机器学习实践》(Python'
' Machine Learning),它涵盖了基础知识和一些实际案例。此外,您还可以参考Coursera'
'或edX上的在线课程,这些课程提供了系统的学习路径。\n'
'用户:谢谢你的建议!我还想知道,学习机器学习需要什么样的数学基础?\n'
'意图分析:用户在寻求信息,希望了解学习机器学习所需的前提条件,特别是在数学方面。\n'
'意图类别:信息查找\n'
'LLM:学习机器学习通常需要一定的数学基础,特别是线性代数、概率论和统计学。这些数学领'
'域帮助理解算法的工作原理和数据模式分析。如果您对这些主题不太熟悉,建议先从相关基础'
'书籍或在线资源开始学习。\n'
'用户:明白了,我会先补习这些基础知识。再次感谢你的帮助!\n'
'意图分析:用户表达感谢,并表示计划付诸行动来补充所需的基础知识。\n'
'意图类别:其他')
DEFAULT_QUERY_TEMPLATE = '用户:{query}\n'
DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n'
DEFAULT_CANDIDATES_TEMPLATE = '备选意图类别:[{candidate_str}]'
DEFAULT_ANALYSIS_TEMPLATE = '意图分析:{analysis}\n'
DEFAULT_LABELS_TEMPLATE = '意图类别:{labels}\n'
DEFAULT_ANALYSIS_PATTERN = '意图分析:(.*?)\n'
DEFAULT_LABELS_PATTERN = '意图类别:(.*?)($|\n)'
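
    # Gloss of the Chinese markers used above: '用户' = 'User',
    # '备选意图类别' = 'candidate intent categories', '意图分析' = 'intent
    # analysis', '意图类别' = 'intent category'. The two regex patterns must
    # match the same markers that the templates emit, so if you override the
    # templates you usually need to override the patterns as well.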
def __init__(self,
api_model: str = 'gpt-4o',
intent_candidates: Optional[List[str]] = None,
max_round: NonNegativeInt = 10,
*,
labels_key: str = MetaKeys.dialog_intent_labels,
analysis_key: str = MetaKeys.dialog_intent_labels_analysis,
api_endpoint: Optional[str] = None,
response_path: Optional[str] = None,
system_prompt: Optional[str] = None,
query_template: Optional[str] = None,
response_template: Optional[str] = None,
candidate_template: Optional[str] = None,
analysis_template: Optional[str] = None,
labels_template: Optional[str] = None,
analysis_pattern: Optional[str] = None,
labels_pattern: Optional[str] = None,
try_num: PositiveInt = 3,
model_params: Dict = {},
sampling_params: Dict = {},
**kwargs):
"""
Initialization method.
        :param api_model: API model name.
        :param intent_candidates: The candidate intent labels to choose
            from. If it is None, open-domain intent labels are used.
        :param max_round: The maximum number of dialog rounds kept when
            building the prompt.
        :param labels_key: The key name in the meta field to store the
            output labels. It is 'dialog_intent_labels' by default.
        :param analysis_key: The key name in the meta field to store the
            corresponding analyses. It is 'dialog_intent_labels_analysis'
            by default.
        :param api_endpoint: URL endpoint for the API.
        :param response_path: Path to extract content from the API
            response. Defaults to 'choices.0.message.content'.
        :param system_prompt: System prompt for the task.
        :param query_template: Template for the query part of the input
            prompt.
        :param response_template: Template for the response part of the
            input prompt.
        :param candidate_template: Template for the intent candidates part
            of the input prompt.
        :param analysis_template: Template for the analysis part of the
            input prompt.
        :param labels_template: Template for the labels part of the input
            prompt.
        :param analysis_pattern: Pattern used to parse the intent analysis
            from the model response.
        :param labels_pattern: Pattern used to parse the intent labels
            from the model response.
        :param try_num: The number of retry attempts when an API call
            error or an output parsing error occurs.
        :param model_params: Parameters for initializing the API model.
        :param sampling_params: Extra parameters passed to the API call,
            e.g. {'temperature': 0.9, 'top_p': 0.95}.
        :param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)
self.intent_candidates = intent_candidates
self.max_round = max_round
self.labels_key = labels_key
self.analysis_key = analysis_key
self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE
self.response_template = response_template or \
self.DEFAULT_RESPONSE_TEMPLATE
self.candidate_template = candidate_template or \
self.DEFAULT_CANDIDATES_TEMPLATE
self.analysis_template = analysis_template or \
self.DEFAULT_ANALYSIS_TEMPLATE
self.labels_template = labels_template or \
self.DEFAULT_LABELS_TEMPLATE
self.analysis_pattern = analysis_pattern or \
self.DEFAULT_ANALYSIS_PATTERN
self.labels_pattern = labels_pattern or \
self.DEFAULT_LABELS_PATTERN
self.sampling_params = sampling_params
self.model_key = prepare_model(model_type='api',
model=api_model,
endpoint=api_endpoint,
response_path=response_path,
**model_params)
self.try_num = try_num
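
    def build_input(self, history, qa):
        # NOTE: this helper is called by process_single below but its body is
        # missing from this section. What follows is a minimal reconstruction
        # inferred from how the templates, the history list and max_round are
        # used in this class; the upstream implementation may differ in detail.
        if self.intent_candidates:
            # Constrain the model to the configured candidate intent labels.
            input_prompt = self.candidate_template.format(
                candidate_str=','.join(self.intent_candidates))
        else:
            input_prompt = ''
        if self.max_round > 0:
            # Each labeled round appends 4 entries to history (query,
            # analysis, labels, response), so keep the last max_round rounds.
            input_prompt += ''.join(history[-self.max_round * 4:])
        # Finally add the query of the round that is being labeled now.
        input_prompt += self.query_template.format(query=qa[0])
        return input_prompt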
def parse_output(self, response):
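        # Pull the analysis text and the label text out of the raw model
        # response using the configured regex patterns; either value falls
        # back to an empty string when its pattern does not match.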
analysis = ''
labels = ''
match = re.search(self.analysis_pattern, response)
if match:
analysis = match.group(1)
match = re.search(self.labels_pattern, response)
if match:
labels = match.group(1)
return analysis, labels
def process_single(self, sample, rank=None):
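        # Label the user intent of every round in the dialog, one API call
        # per round, and store the per-round labels and analyses in the
        # sample's meta field. Samples that are already labeled are skipped.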
meta = sample[Fields.meta]
if self.labels_key in meta and self.analysis_key in meta:
return sample
client = get_model(self.model_key, rank=rank)
analysis_list = []
labels_list = []
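        # 'history' accumulates the rounds that have already been labeled,
        # formatted with the templates; 'dialog' collects the (query,
        # response) rounds to label, including the current query/response
        # pair when one is present in the sample.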
history = []
dialog = sample[self.history_key]
if self.query_key in sample and sample[self.query_key]:
if self.response_key in sample and sample[self.response_key]:
dialog.append(
(sample[self.query_key], sample[self.response_key]))
else:
dialog.append((sample[self.query_key], ''))
for qa in dialog:
input_prompt = self.build_input(history, qa)
messages = [{
'role': 'system',
'content': self.system_prompt,
}, {
'role': 'user',
'content': input_prompt,
}]
            # Keep defaults so a round whose every attempt fails still
            # records empty strings instead of raising a NameError.
            analysis, labels = '', ''
            for _ in range(self.try_num):
try:
response = client(messages, **self.sampling_params)
analysis, labels = self.parse_output(response)
if len(analysis) > 0:
break
except Exception as e:
logger.warning(f'Exception: {e}')
analysis_list.append(analysis)
labels_list.append(labels)
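            # Append this round in template form (query, analysis, labels,
            # response) so that later rounds see the full labeled context.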
history.append(self.query_template.format(query=qa[0]))
history.append(self.analysis_template.format(analysis=analysis))
history.append(self.labels_template.format(labels=labels))
history.append(self.response_template.format(response=qa[1]))
meta[self.labels_key] = labels_list
meta[self.analysis_key] = analysis_list
return sample
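

# A minimal usage sketch (illustrative, not part of the module): it assumes
# the standard Data-Juicer 'history'/'query'/'response' sample keys, an
# already-present meta field, and a configured OpenAI-compatible API backend.
#
#   op = DialogIntentDetectionMapper(
#       api_model='gpt-4o',
#       intent_candidates=['信息查找', '请求建议', '其他'],
#       max_round=5,
#   )
#   sample = {
#       Fields.meta: {},
#       'history': [],
#       'query': '你好,能给我讲讲什么是机器学习吗?',
#       'response': '',
#   }
#   sample = op.process_single(sample)
#   print(sample[Fields.meta][MetaKeys.dialog_intent_labels])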