@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class QueryTopicDetectionMapper(Mapper):
    """Predicts the topic label and its corresponding score for a given query.

    The input is taken from the specified query key. The output, which
    includes the predicted topic label and its score, is stored in the
    'query_topic_label' and 'query_topic_label_score' fields of the
    Data-Juicer meta field. This operator uses a Hugging Face model for topic
    classification. If a Chinese to English translation model is provided, it
    will first translate the query from Chinese to English before predicting
    the topic.

    - Uses a Hugging Face model for topic classification.
    - Optionally translates Chinese queries to English using another Hugging
      Face model.
    - Stores the predicted topic label in 'query_topic_label'.
    - Stores the corresponding score in 'query_topic_label_score'.
    """

    # Class-level operator settings: the accelerator name and the flag
    # marking this operator as a batched one.
    _accelerator = "cuda"
    _batched_op = True

    def __init__(
        self,
        hf_model: str = "dstefa/roberta-base_topic_classification_nyt_news",  # noqa: E501 E131
        zh_to_en_hf_model: Optional[str] = "Helsinki-NLP/opus-mt-zh-en",
        model_params: Dict = {},
        zh_to_en_model_params: Dict = {},
        *,
        label_key: str = MetaKeys.query_topic_label,
        score_key: str = MetaKeys.query_topic_score,
        **kwargs,
    ):
        """
        Initialization method.

        :param hf_model: Huggingface model ID to predict topic label.
        :param zh_to_en_hf_model: Translation model from Chinese to English.
            If not None, translate the query from Chinese to English.
        :param model_params: model param for hf_model. ``None`` is treated
            as an empty dict.
        :param zh_to_en_model_params: model param for zh_to_en_hf_model.
            ``None`` is treated as an empty dict.
        :param label_key: The key name in the meta field to store the
            output label. It is 'query_topic_label' in default.
        :param score_key: The key name in the meta field to store the
            corresponding label score. It is 'query_topic_label_score'
            in default.
        :param kwargs: Extra keyword arguments.
        """
        super().__init__(**kwargs)
        self.label_key = label_key
        self.score_key = score_key

        # The ``{}`` defaults are shared between calls, so they are only ever
        # unpacked here and never mutated. ``or {}`` additionally tolerates an
        # explicit ``None`` instead of failing on ``**None``.
        self.model_key = prepare_model(
            model_type="huggingface",
            pretrained_model_name_or_path=hf_model,
            return_pipe=True,
            pipe_task="text-classification",
            **(model_params or {}),
        )

        # The translation pipeline is optional: the key stays ``None`` when no
        # translation model is configured.
        if zh_to_en_hf_model is not None:
            self.zh_to_en_model_key = prepare_model(
                model_type="huggingface",
                pretrained_model_name_or_path=zh_to_en_hf_model,
                return_pipe=True,
                pipe_task="translation",
                **(zh_to_en_model_params or {}),
            )
        else:
            self.zh_to_en_model_key = None