def TEST_TAG(*tags):
    """Tags for test case.

    Currently, `standalone`, `ray` are supported.

    The tags are stored on the decorated function as ``__test_tags__``.
    While the wrapped test runs, ``self.current_tag`` is set to the first
    tag (when at least one tag is given). Afterwards the previous value is
    written back, even if the test raises. If the attribute was absent,
    ``'standalone'`` is written back.
    """
    def decorator(test_func):
        test_func.__test_tags__ = tags

        @functools.wraps(test_func)
        def wrapper(self, *args, **kwargs):
            # Remember the tag in effect before this test so it can be restored.
            previous = getattr(self, 'current_tag', 'standalone')
            if tags:
                self.current_tag = tags[0]
            try:
                return test_func(self, *args, **kwargs)
            finally:
                self.current_tag = previous

        return wrapper

    return decorator
def set_clear_model_flag(flag):
    """Choose whether downloaded models are removed after the unit tests.

    Stores ``flag`` in the module-level ``CLEAR_MODEL`` variable and prints
    which policy is now in effect.

    Args:
        flag: truthy to clear downloaded models, falsy to keep them.
    """
    global CLEAR_MODEL
    CLEAR_MODEL = flag
    message = ('CLEAR DOWNLOADED MODELS AFTER UNITTESTS.'
               if CLEAR_MODEL else 'KEEP DOWNLOADED MODELS AFTER UNITTESTS.')
    print(message)
@classmethod
def setUpClass(cls):
    """Prepare class-wide state shared by every test in the case.

    ``maxDiff`` is read from the ``TEST_MAX_DIFF`` environment variable. The
    literal string ``'None'`` means no limit, and it is also the value used
    when the variable is unset. The current ``multiprocess`` start method is
    saved in ``cls.original_mp_method``. ``'spawn'`` is forced when
    ``is_cuda_available()`` is truthy.
    """
    raw_max_diff = os.getenv('TEST_MAX_DIFF', 'None')
    if raw_max_diff == 'None':
        cls.maxDiff = None
    else:
        cls.maxDiff = int(raw_max_diff)

    import multiprocess
    cls.original_mp_method = multiprocess.get_start_method()
    # NOTE(review): presumably 'spawn' is needed because CUDA cannot be
    # re-initialised in forked child processes -- confirm.
    if is_cuda_available():
        multiprocess.set_start_method('spawn', force=True)
@classmethod
def tearDownClass(cls, hf_model_name=None) -> None:
    """Restore the multiprocess start method and optionally clean model caches.

    The start method saved in ``cls.original_mp_method`` is always restored.
    Cache files are deleted only when the module-level ``CLEAR_MODEL`` flag
    is truthy:

    * if ``hf_model_name`` is given, only that model's directory
      (``models--<org>--<name>``) under ``transformers.TRANSFORMERS_CACHE``
      is removed;
    * otherwise the whole ``transformers.TRANSFORMERS_CACHE`` directory is
      removed.

    Args:
        hf_model_name: optional Hugging Face model id such as ``'org/name'``.
    """
    import multiprocess
    multiprocess.set_start_method(cls.original_mp_method, force=True)

    if not CLEAR_MODEL:
        # Keep downloaded models; nothing to clean.
        return

    if not hf_model_name:
        # No specific model requested: wipe the whole cache directory.
        whole_cache = transformers.TRANSFORMERS_CACHE
        if os.path.exists(whole_cache):
            print('CLEAN all TRANSFORMERS_CACHE')
            shutil.rmtree(whole_cache)
        return

    # The cache stores each model as 'models--' + id with '/' replaced by '--'.
    target_dir = os.path.join(
        transformers.TRANSFORMERS_CACHE,
        f'models--{hf_model_name.replace("/","--")}',
    )
    if os.path.exists(target_dir):
        print(f'CLEAN model cache files for {hf_model_name}')
        shutil.rmtree(target_dir)
def generate_dataset(self, data) -> DJDataset:
    """Build a dataset from ``data`` for the executor selected by the current tag.

    The executor is chosen from ``self.current_tag``, which defaults to
    ``'standalone'`` when the attribute is not set.

    Args:
        data: the raw samples. Presumably a list of dicts, one per row,
            since it is passed to ``NestedDataset.from_list`` /
            ``ray.data.from_items`` -- confirm with callers.

    Returns:
        A ``NestedDataset`` when the tag starts with ``'standalone'``, or a
        ``RayDataset`` wrapping ``ray.data.from_items(data)`` when the tag
        starts with ``'ray'``.

    Raises:
        ValueError: if the current tag matches neither executor.
    """
    current_tag = getattr(self, 'current_tag', 'standalone')
    if current_tag.startswith('standalone'):
        return NestedDataset.from_list(data)
    elif current_tag.startswith('ray'):
        # Only import Ray when needed
        ray = LazyLoader('ray')
        from data_juicer.core.data.ray_dataset import RayDataset
        dataset = ray.data.from_items(data)
        return RayDataset(dataset)
    else:
        # Name the offending tag so misconfigured tests are easy to diagnose.
        raise ValueError(f'Unsupported type: {current_tag}')
def run_single_op(self, dataset: DJDataset, op, column_names):
    """Run operator in the specific executor and return the resulting rows.

    The executor is chosen from ``self.current_tag``, which defaults to
    ``'standalone'`` when the attribute is not set.

    Args:
        dataset: dataset built for the current executor.
        op: operator passed to ``dataset.process``.
        column_names: a column name, or a list of column names, to keep in
            the output.

    Returns:
        A list of dicts, one per row, restricted to ``column_names``. In Ray
        mode an empty list is returned when the processed table lacks any of
        the requested columns (``DataFrame.get`` returns ``None``). This
        presumably happens when every row was filtered out -- confirm.

    Raises:
        ValueError: if the current tag matches neither executor. The operator
            has already been applied to ``dataset`` by then.
    """
    current_tag = getattr(self, 'current_tag', 'standalone')
    dataset = dataset.process(op)
    if current_tag.startswith('standalone'):
        dataset = dataset.select_columns(column_names=column_names)
        return dataset.to_list()
    elif current_tag.startswith('ray'):
        # A bare string would make DataFrame.get return a Series, whose
        # to_dict() does not accept `orient`. Normalise it to a list so both
        # executors accept the same inputs.
        if isinstance(column_names, str):
            column_names = [column_names]
        dataset = dataset.data.to_pandas().get(column_names)
        if dataset is None:
            return []
        return dataset.to_dict(orient='records')
    else:
        # Name the offending tag so misconfigured tests are easy to diagnose.
        raise ValueError(f'Unsupported type: {current_tag}')
def assertDatasetEqual(self, first, second):
    """Assert that two collections of records hold the same rows, ignoring order.

    Each record is a dict. Values that are ``list`` or ``numpy.ndarray`` are
    normalised to tuples before comparison. An ndarray compared with ``==``
    yields an elementwise array rather than a bool, so tuples keep dict
    equality and sorting well-defined. Rows are then sorted into a canonical
    order and compared with ``self.assertEqual``. The caller's records are
    not modified.

    Args:
        first: iterable of record dicts to check.
        second: iterable of expected record dicts.
    """
    def convert_record(rec):
        # Build a converted copy instead of rewriting ``rec`` in place, so the
        # caller's records keep their original list/ndarray values.
        return {
            key: tuple(value)
            if isinstance(value, (numpy.ndarray, list)) else value
            for key, value in rec.items()
        }

    def sort_key(rec):
        return tuple(sorted(rec.items()))

    first = sorted((convert_record(d) for d in first), key=sort_key)
    second = sorted((convert_record(d) for d in second), key=sort_key)
    return self.assertEqual(first, second)
# Helpers for partial unit testing: run only the tests affected by the changed files.
def get_diff_files(prefix_filter=('data_juicer/', 'tests/')):
    """Get git diff files in target dirs except the __init__.py files.

    Runs ``git diff --name-only --diff-filter=ACMRT origin/main``. From its
    output it keeps the ``.py`` files whose path starts with one of the
    given prefixes. Files ending with ``__init__.py`` are excluded.

    Args:
        prefix_filter: a path prefix, or an iterable of path prefixes.
            Defaults to ``('data_juicer/', 'tests/')``. The default is a
            tuple so it cannot be mutated across calls.

    Returns:
        List of matching file paths, in git's output order.

    Raises:
        subprocess.CalledProcessError: if the git command exits non-zero
            (for example when ``origin/main`` is not available).
    """
    if isinstance(prefix_filter, str):
        prefix_filter = (prefix_filter,)
    # str.startswith accepts a tuple of prefixes and checks them all in one call.
    prefixes = tuple(prefix_filter)
    output = subprocess.check_output(
        ['git', 'diff', '--name-only', '--diff-filter=ACMRT', 'origin/main'],
        universal_newlines=True,
    )
    return [
        path for path in output.strip().splitlines()
        if path.startswith(prefixes)
        and path.endswith('.py')
        and not path.endswith('__init__.py')
    ]
def get_partial_test_cases():
    """Map every changed source file to its corresponding test file.

    Uses ``get_diff_files()`` to list the changed files. Each one is passed
    to ``find_corresponding_test_file``, and all files are looked up before
    any result is checked.

    Returns:
        The list of test files, or ``None`` if any changed file has no
        corresponding test file. ``None`` signals that the full test suite
        should run instead.
    """
    candidates = list(map(find_corresponding_test_file, get_diff_files()))
    # A single unmapped change means we cannot safely restrict the test run.
    return None if None in candidates else candidates