Source code for data_juicer.utils.unittest_utils

import functools
import os
import shutil
import subprocess
import unittest

import numpy

from data_juicer import is_cuda_available
from data_juicer.core.data import DJDataset, NestedDataset
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import free_models

# Handle to the `transformers` package obtained through LazyLoader
# (presumably the real import is deferred until first attribute access —
# confirm in data_juicer.utils.lazy_loader). It is read in
# `DataJuicerTestCaseBase.tearDownClass` to locate `TRANSFORMERS_CACHE`.
transformers = LazyLoader("transformers")

# Module-wide switch: when truthy, `DataJuicerTestCaseBase.tearDownClass`
# deletes cached HuggingFace model files. Change it via `set_clear_model_flag`.
CLEAR_MODEL = False


def TEST_TAG(*tags):
    """Tags for test case.

    Currently, `standalone`, `ray` are supported.

    The given tags are stored on the decorated function as
    ``__test_tags__``. While the decorated test method runs, the first tag
    is exposed on the test instance as ``self.current_tag``; once the method
    finishes (normally or by raising) the previous value is written back.

    Args:
        *tags: tag names for the test method. With no tags, ``current_tag``
            is left as it is for the duration of the call.

    Returns:
        A decorator for test methods whose first positional argument is the
        test instance.
    """

    def decorator(func):
        # Record the full tag tuple on the undecorated function;
        # functools.wraps below carries it over to the wrapper as well.
        func.__test_tags__ = tags

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Remember what was active before this test; instances that
            # never had a tag are treated as "standalone".
            previous_tag = getattr(self, "current_tag", "standalone")

            if tags:
                # Only the first tag drives executor selection.
                self.current_tag = tags[0]

            try:
                return func(self, *args, **kwargs)
            finally:
                # Always put the earlier tag back, even if the test raised.
                self.current_tag = previous_tag

        return wrapper

    return decorator
def set_clear_model_flag(flag):
    """Set the module-level ``CLEAR_MODEL`` switch and announce the choice.

    Args:
        flag: new value for ``CLEAR_MODEL``. A truthy value means downloaded
            models are cleared after the unit tests; a falsy value keeps
            them.
    """
    global CLEAR_MODEL
    CLEAR_MODEL = flag

    # Report the resulting mode on stdout so it shows up in the test log.
    message = (
        "CLEAR DOWNLOADED MODELS AFTER UNITTESTS."
        if CLEAR_MODEL
        else "KEEP DOWNLOADED MODELS AFTER UNITTESTS."
    )
    print(message)
class DataJuicerTestCaseBase(unittest.TestCase):
    """Common base class for the project's unit tests.

    It adjusts the ``multiprocess`` start method around each test class,
    frees in-memory models between tests, optionally removes cached
    HuggingFace models afterwards, and offers helpers to build datasets,
    run a single operator and compare results. The dataset helpers pick the
    executor from ``self.current_tag`` (set by the ``TEST_TAG`` decorator),
    defaulting to ``"standalone"``.
    """

    @classmethod
    def setUpClass(cls):
        """Prepare class-wide test state.

        Reads the ``TEST_MAX_DIFF`` environment variable to configure
        ``maxDiff`` (unset or ``"None"`` means no limit), remembers the
        current ``multiprocess`` start method so ``tearDownClass`` can
        restore it, and frees models held in memory.

        Raises:
            ValueError: if ``TEST_MAX_DIFF`` is set to something other than
                ``"None"`` or an integer literal.
        """
        # Set maxDiff for all test cases based on an environment variable
        max_diff = os.getenv("TEST_MAX_DIFF", "None")
        cls.maxDiff = None if max_diff == "None" else int(max_diff)

        import multiprocess

        cls.original_mp_method = multiprocess.get_start_method()
        if is_cuda_available():
            # NOTE(review): presumably "spawn" is forced because CUDA cannot
            # be used safely in forked worker processes — confirm.
            multiprocess.set_start_method("spawn", force=True)

        # clear models in memory
        free_models()

    @classmethod
    def tearDownClass(cls, hf_model_name=None) -> None:
        """Restore the start method and optionally clean model caches.

        Args:
            hf_model_name: HuggingFace model name (e.g. ``"org/model"``).
                When given, only that model's cache directory is removed;
                otherwise the whole ``TRANSFORMERS_CACHE`` directory is.
                Cache removal only happens when ``CLEAR_MODEL`` is truthy.
        """
        import multiprocess

        multiprocess.set_start_method(cls.original_mp_method, force=True)

        # clean the huggingface model cache files
        if not CLEAR_MODEL:
            return

        if hf_model_name:
            # given the hf model name, remove this model only
            model_dir = os.path.join(
                transformers.TRANSFORMERS_CACHE,
                f'models--{hf_model_name.replace("/", "--")}',
            )
            if os.path.exists(model_dir):
                print(f"CLEAN model cache files for {hf_model_name}")
                shutil.rmtree(model_dir)
        elif os.path.exists(transformers.TRANSFORMERS_CACHE):
            # not given the hf model name, remove the whole TRANSFORMERS_CACHE
            print("CLEAN all TRANSFORMERS_CACHE")
            shutil.rmtree(transformers.TRANSFORMERS_CACHE)

    def tearDown(self) -> None:
        """Free models held in memory after every test method."""
        # clear models in memory
        free_models()

    def generate_dataset(self, data) -> DJDataset:
        """Generate dataset for the executor selected by ``current_tag``.

        Args:
            data: list of sample records to wrap in a dataset.

        Returns:
            A ``NestedDataset`` when ``self.current_tag`` starts with
            ``"standalone"`` (the default when no tag is set), or a
            ``RayDataset`` when it starts with ``"ray"``.

        Raises:
            ValueError: if ``self.current_tag`` matches neither executor.
        """
        current_tag = getattr(self, "current_tag", "standalone")
        if current_tag.startswith("standalone"):
            return NestedDataset.from_list(data)
        elif current_tag.startswith("ray"):
            # Only import Ray when needed
            ray = LazyLoader("ray")
            from data_juicer.core.data.ray_dataset import RayDataset

            dataset = ray.data.from_items(data)
            return RayDataset(dataset)
        else:
            raise ValueError("Unsupported type")

    def run_single_op(self, dataset: DJDataset, op, column_names):
        """Run operator in the specific executor.

        Args:
            dataset: dataset to process, as built by ``generate_dataset``.
            op: operator handed to ``dataset.process``.
            column_names: columns to keep in the returned records.

        Returns:
            A list of record dicts restricted to ``column_names``. In the
            ``ray`` branch an empty list is returned when the column lookup
            on the resulting pandas frame yields ``None``.

        Raises:
            ValueError: if ``self.current_tag`` matches neither executor.
        """
        current_tag = getattr(self, "current_tag", "standalone")
        dataset = dataset.process(op)
        if current_tag.startswith("standalone"):
            dataset = dataset.select_columns(column_names=column_names)
            return dataset.to_list()
        elif current_tag.startswith("ray"):
            frame = dataset.data.to_pandas().get(column_names)
            if frame is None:
                return []
            return frame.to_dict(orient="records")
        else:
            raise ValueError("Unsupported type")

    def assertDatasetEqual(self, first, second):
        """Assert that two collections of records are equal ignoring order.

        List and ``numpy.ndarray`` values are compared as tuples so that
        records can be sorted into a canonical order before comparison. The
        records passed in are left untouched: normalisation works on copies.

        Args:
            first: iterable of record dicts.
            second: iterable of record dicts to compare against ``first``.

        Raises:
            AssertionError: if the two normalised, sorted lists differ.
        """

        def normalize(record):
            # Convert incomparable `list`/`ndarray` values to comparable
            # `tuple`s in a fresh dict, so the caller's data is not mutated.
            return {
                key: tuple(value) if isinstance(value, (numpy.ndarray, list)) else value
                for key, value in record.items()
            }

        def sort_key(record):
            # Canonical, order-independent key for one record.
            return tuple(sorted(record.items()))

        first = sorted((normalize(rec) for rec in first), key=sort_key)
        second = sorted((normalize(rec) for rec in second), key=sort_key)
        return self.assertEqual(first, second)
# Helpers for partial unit testing: run only the tests affected by the current git diff.
def get_diff_files(prefix_filter=("data_juicer/", "tests/")):
    """Get git diff files in target dirs except the __init__.py files.

    Runs ``git diff --name-only --diff-filter=ACMRT origin/main`` in the
    current working directory and filters the reported paths.

    Args:
        prefix_filter: iterable of path prefixes; a changed file is kept
            only if its path starts with one of them. The default is an
            immutable tuple so it cannot be altered between calls.

    Returns:
        List of changed file paths that start with one of the prefixes, end
        with ``.py`` and are not ``__init__.py`` files.

    Raises:
        subprocess.CalledProcessError: if the ``git`` command exits with a
            non-zero status.
    """
    # str.startswith accepts a tuple of candidates, so normalise once here
    # (callers may still pass a list).
    prefixes = tuple(prefix_filter)

    diff_output = subprocess.check_output(
        ["git", "diff", "--name-only", "--diff-filter=ACMRT", "origin/main"],
        universal_newlines=True,
    )
    changed_files = diff_output.strip().split("\n")

    return [
        path
        for path in changed_files
        if path.startswith(prefixes) and path.endswith(".py") and not path.endswith("__init__.py")
    ]
def find_corresponding_test_file(file_path):
    """Map a source file path to its test file, if that test file exists.

    The candidate is built by replacing ``"data_juicer"`` with ``"tests"``
    in the path and, unless the file name already starts with ``test_`` or
    is ``run.py``, prefixing the file name with ``test_``.

    Args:
        file_path: path of a changed file.

    Returns:
        The candidate test file path if it exists on disk, otherwise
        ``None``.
    """
    # NOTE(review): str.replace rewrites every occurrence of "data_juicer"
    # in the path, not only the leading directory.
    candidate = file_path.replace("data_juicer", "tests")
    parent, name = os.path.split(candidate)

    needs_prefix = name != "run.py" and not name.startswith("test_")
    if needs_prefix:
        candidate = os.path.join(parent, "test_" + name)

    return candidate if os.path.exists(candidate) else None
def get_partial_test_cases():
    """Collect the test files that correspond to the files changed in git.

    Returns:
        A list with one test file path per changed file reported by
        ``get_diff_files``, or ``None`` when at least one changed file has
        no corresponding test file (meaning: run all tests).
    """
    changed = get_diff_files()
    candidates = [find_corresponding_test_file(path) for path in changed]

    # can't find corresponding test files for some changed files: run all
    has_unmatched = None in candidates
    return None if has_unmatched else candidates