@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class QueryIntentDetectionMapper(Mapper):
    """Predicts the user's intent label and corresponding score for a query.

    The operator uses a Hugging Face text-classification pipeline to classify
    the intent of the input query. If a Chinese-to-English translation model
    is configured, the query can be translated to English with a second
    Hugging Face pipeline before classification.

    The predicted intent label and its confidence score are stored in the
    meta field under ``label_key`` and ``score_key`` (by default
    ``MetaKeys.query_intent_label`` and ``MetaKeys.query_intent_score``).
    If these keys already exist in the meta field, the operator will skip
    processing for those samples.
    """

    # NOTE(review): presumably consumed by the Mapper base class / framework
    # to choose the device and the batched processing path -- confirm there.
    _accelerator = "cuda"
    _batched_op = True

    def __init__(
        self,
        hf_model: str = "bespin-global/klue-roberta-small-3i4k-intent-classification",  # noqa: E501 E131
        zh_to_en_hf_model: Optional[str] = "Helsinki-NLP/opus-mt-zh-en",
        model_params: Optional[Dict] = None,
        zh_to_en_model_params: Optional[Dict] = None,
        *,
        label_key: str = MetaKeys.query_intent_label,
        score_key: str = MetaKeys.query_intent_score,
        **kwargs,
    ):
        """
        Initialization method.

        :param hf_model: Huggingface model ID to predict intent label.
        :param zh_to_en_hf_model: Translation model from Chinese to English.
            If not None, translate the query from Chinese to English.
        :param model_params: Extra keyword arguments forwarded to
            ``prepare_model`` for ``hf_model``. ``None`` (the default) means
            no extra arguments.
        :param zh_to_en_model_params: Extra keyword arguments forwarded to
            ``prepare_model`` for ``zh_to_en_hf_model``. ``None`` (the
            default) means no extra arguments.
        :param label_key: The key name in the meta field to store the
            output label. Defaults to ``MetaKeys.query_intent_label``.
        :param score_key: The key name in the meta field to store the
            corresponding label score. Defaults to
            ``MetaKeys.query_intent_score``.
        :param kwargs: Extra keyword arguments.
        """
        super().__init__(**kwargs)

        # Use fresh dicts instead of shared mutable defaults; an explicit
        # ``None`` from the caller is treated the same as "no extra params".
        model_params = {} if model_params is None else model_params
        zh_to_en_model_params = (
            {} if zh_to_en_model_params is None else zh_to_en_model_params
        )

        self.label_key = label_key
        self.score_key = score_key

        # Intent classifier, prepared as a "text-classification" pipeline.
        self.model_key = prepare_model(
            model_type="huggingface",
            pretrained_model_name_or_path=hf_model,
            return_pipe=True,
            pipe_task="text-classification",
            **model_params,
        )

        # The translation pipeline is optional; a ``None`` key records that
        # no zh->en translation model was configured.
        if zh_to_en_hf_model is not None:
            self.zh_to_en_model_key = prepare_model(
                model_type="huggingface",
                pretrained_model_name_or_path=zh_to_en_hf_model,
                return_pipe=True,
                pipe_task="translation",
                **zh_to_en_model_params,
            )
        else:
            self.zh_to_en_model_key = None