#!/usr/bin/env python3
"""
Operator Searcher - A tool for filtering Data-Juicer operators by tags
"""
import inspect
import re
from typing import Dict, List, Optional
from data_juicer.ops import OPERATORS
[文档]
class OPRecord:
"""A record class for storing operator metadata"""
[文档]
def __init__(self, op_type: str, name: str, desc: str, tags: List[str], sig: inspect.Signature, param_desc: str):
self.type = op_type
self.name = name
self.desc = desc
self.tags = tags
self.sig = sig
self.param_desc = param_desc
[文档]
def to_dict(self):
return {
"type": self.type,
"name": self.name,
"description": self.desc,
"tags": self.tags,
"signature": self.sig,
"param_desc": self.param_desc,
}
# OP tag analysis functions
[文档]
def analyze_modality_tag(code, op_prefix):
"""
Analyze the modality tag for the given code content string. Should be one
of the "Modality Tags" in `tagging_mappings.json`. It makes the choice by
finding the usages of attributes `{modality}_key` and the prefix of the OP
name. If there are multiple modality keys are used, the 'multimodal' tag
will be returned instead.
"""
tags = []
if "self.text_key" in code or op_prefix == "text":
tags.append("text")
if "self.image_key" in code or op_prefix == "image":
tags.append("image")
if "self.audio_key" in code or op_prefix == "audio":
tags.append("audio")
if "self.video_key" in code or op_prefix == "video":
tags.append("video")
if len(tags) > 1:
tags = ["multimodal"]
return tags
[文档]
def analyze_resource_tag(code):
"""
Analyze the resource tag for the given code content string. Should be one
of the "Resource Tags" in `tagging_mappings.json`. It makes the choice
according to their assigning statement to attribute `_accelerator`.
"""
if "_accelerator = " in code:
if "_accelerator = 'cuda'" in code:
return ["gpu"]
else:
return ["cpu"]
else:
return []
# def analyze_tag_from_code(content, op_name):
# """
# Analyze the tags for the OP from the given code.
# """
# tags = []
# op_prefix = op_name.split('_')[0]
# tags.extend(analyze_modality_tag(content, op_prefix))
# tags.extend(analyze_resource_tag(content))
# tags.extend(analyze_model_tags(content))
# return tags
[文档]
def analyze_tag_with_inheritance(op_cls, analyze_func, default_tags=[], other_parm=dict()):
"""
Universal inheritance chain label analysis function
"""
mro_classes = op_cls.__mro__[:3]
for cls in mro_classes:
try:
current_code = inspect.getsource(cls)
current_tags = analyze_func(current_code, **other_parm)
if len(current_tags) > 0:
return current_tags
except (OSError, TypeError):
continue
return default_tags
[文档]
def analyze_tag_from_cls(op_cls, op_name):
"""
Analyze the tags for the OP from the given cls.
"""
tags = []
op_prefix = op_name.split("_")[0]
content = inspect.getsource(op_cls)
# Try to find from the inheritance chain
resource_tags = analyze_tag_with_inheritance(op_cls, analyze_resource_tag, default_tags=["cpu"])
model_tags = analyze_tag_with_inheritance(op_cls, analyze_model_tags)
tags.extend(resource_tags)
tags.extend(model_tags)
tags.extend(analyze_modality_tag(content, op_prefix))
return tags
[文档]
class OPSearcher:
"""Operator search engine"""
[文档]
def __init__(self, specified_op_list: Optional[List[str]] = None):
if specified_op_list:
self.op_records = self._scan_specified_ops(specified_op_list)
else:
self.op_records = self._scan_all_ops()
def _scan_specified_ops(self, specified_op_list: List[str]) -> List[OPRecord]:
"""Scan specified operators"""
records = []
for op_name in specified_op_list:
op_type = op_name.split("_")[-1]
op_cls = OPERATORS.modules[op_name]
desc = op_cls.__doc__ or ""
tags = analyze_tag_from_cls(op_cls, op_name)
sig = inspect.signature(op_cls.__init__)
init_param_desc = extract_param_docstring(op_cls.__init__.__doc__ or "")
records.append(
OPRecord(op_type=op_type, name=op_name, desc=desc, tags=tags, sig=sig, param_desc=init_param_desc)
)
return records
def _scan_all_ops(self) -> List[OPRecord]:
"""Scan all operators"""
all_ops_list = list(OPERATORS.modules.keys())
return self._scan_specified_ops(all_ops_list)
[文档]
def search(
self, tags: Optional[List[str]] = None, op_type: Optional[str] = None, match_all: bool = True
) -> List[Dict]:
"""
Search operators by criteria
:param tags: List of tags to match
:param op_type: Operator type (mapper/filter/etc)
:param match_all: True requires matching all tags, False matches any tag
:return: List of matched operator records
"""
results = []
for record in self.op_records:
# Filter by type
if op_type and record.type != op_type:
continue
# Filter by tags
if tags:
tags = [tag.lower() for tag in tags]
if match_all:
if not all(tag in record.tags for tag in tags):
continue
else:
if not any(tag in record.tags for tag in tags):
continue
results.append(record.to_dict())
return results
[文档]
def main(tags, op_type):
searcher = OPSearcher()
results = searcher.search(tags=tags, op_type=op_type)
print(f"\nFound {len(results)} operators:")
for op in results:
print(f"\n[{op['type'].upper()}] {op['name']}")
print(f"Tags: {', '.join(op['tags'])}")
if __name__ == "__main__":
tags = []
op_type = "filter"
main(tags, op_type=op_type)