import os
import re
import shutil
from abc import ABC, abstractmethod
from multiprocessing import Pool
from pathlib import Path
from typing import Dict, List, Optional, Type, Union
from datasets import Dataset
from datasets.utils.extract import Extractor as HF_Extractor
from datasets.utils.filelock import FileLock as HF_FileLock
from loguru import logger
from data_juicer.utils import cache_utils
class FileLock(HF_FileLock):
"""
File lock for compresssion or decompression, and
remove lock file automatically.
"""
def _release(self):
super()._release()
try:
# logger.debug(f'Remove {self._lock_file}')
os.remove(self._lock_file)
# The file is already deleted and that's what we want.
except OSError:
pass
return None
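# A minimal usage sketch of `FileLock` (the path below is hypothetical):
#     with FileLock('/path/to/output.lock'):
#         ...  # exclusive section; the `.lock` file is removed on release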
class BaseCompressor(ABC):
"""
Base class that compresses a file.
"""
@staticmethod
@abstractmethod
def compress(input_path: Union[Path, str], output_path: Union[Path, str]):
"""
Compress input file and save to output file.
:param input_path: path to uncompressed file.
:param output_path: path to compressed file.
"""
...
class ZstdCompressor(BaseCompressor):
"""
This class compresses a file using the `zstd` algorithm.
"""
@staticmethod
def compress(input_path: Union[Path, str], output_path: Union[Path, str]):
"""
Compress input file and save to output file.
:param input_path: path to uncompressed file.
:param output_path: path to compressed file.
"""
import zstandard as zstd
cctx = zstd.ZstdCompressor()
with open(input_path, 'rb') as ifh, open(output_path, 'wb') as ofh:
cctx.copy_stream(ifh, ofh)
class Lz4Compressor(BaseCompressor):
"""
This class compresses a file using the `lz4` algorithm.
"""
@staticmethod
def compress(input_path: Union[Path, str], output_path: Union[Path, str]):
"""
        Compress input file and save to output file.
:param input_path: path to uncompressed file.
:param output_path: path to compressed file.
"""
import lz4.frame
with open(input_path, 'rb') as input_file:
with lz4.frame.open(output_path, 'wb') as compressed_file:
shutil.copyfileobj(input_file, compressed_file)
class GzipCompressor(BaseCompressor):
"""
This class compresses a file using the `gzip` algorithm.
"""
@staticmethod
def compress(input_path: Union[Path, str], output_path: Union[Path, str]):
"""
Compress input file and save to output file.
:param input_path: path to uncompressed file.
:param output_path: path to compressed file.
"""
import gzip
with open(input_path, 'rb') as input_file:
with gzip.open(output_path, 'wb') as compressed_file:
shutil.copyfileobj(input_file, compressed_file)
class Compressor:
"""
    This class contains multiple compressors.
"""
compressors: Dict[str, Type[BaseCompressor]] = {
'gzip': GzipCompressor,
# "zip": ZipCompressor,
# "xz": XzCompressor,
# "rar": RarCompressor,
'zstd': ZstdCompressor,
# "bz2": Bzip2Compressor,
# "7z": SevenZipCompressor,
'lz4': Lz4Compressor,
}
@classmethod
def compress(
cls,
input_path: Union[Path, str],
output_path: Union[Path, str],
compressor_format: str,
):
"""
Compress input file and save to output file.
:param input_path: path to uncompressed file.
:param output_path: path to compressed file.
        :param compressor_format: compression format,
            see supported algorithms in `compressors`.
"""
assert compressor_format in cls.compressors.keys()
os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # Prevent parallel compressions of the same output file
lock_path = str(Path(output_path).with_suffix('.lock'))
with FileLock(lock_path):
shutil.rmtree(output_path, ignore_errors=True)
compressor = cls.compressors[compressor_format]
compressor.compress(input_path, output_path)
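# NOTE: the `Extractor` class below is referenced by `CompressManager` but was
# missing from this listing. This is a hedged reconstruction that mirrors
# `Compressor.compress`: it delegates to the extractors registered in the
# Hugging Face `datasets` `Extractor`, guarded by the same file lock.
class Extractor(HF_Extractor):
    """
    This class extracts a compressed file.
    """

    @classmethod
    def extract(
        cls,
        input_path: Union[Path, str],
        output_path: Union[Path, str],
        extractor_format: str,
    ):
        """
        Extract input file and save to output file.
        :param input_path: path to compressed file.
        :param output_path: path to uncompressed file.
        :param extractor_format: extraction format, see supported
            algorithms in `Extractor` of Hugging Face `datasets`.
        """
        assert extractor_format in cls.extractors
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # Prevent parallel extractions of the same output file
        lock_path = str(Path(output_path).with_suffix('.lock'))
        with FileLock(lock_path):
            shutil.rmtree(output_path, ignore_errors=True)
            extractor = cls.extractors[extractor_format]
            extractor.extract(input_path, output_path)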
class CompressManager:
"""
    This class is used to compress or decompress an input file
    using the specified compression format.
"""
def __init__(self, compressor_format: str = 'zstd'):
"""
Initialization method.
        :param compressor_format: compression format, default `zstd`.
"""
assert compressor_format in Compressor.compressors.keys()
self.compressor_format = compressor_format
self.compressor = Compressor
self.extractor = Extractor
def compress(
self,
input_path: Union[Path, str],
output_path: Union[Path, str],
):
"""
Compress input file and save to output file.
:param input_path: path to uncompressed file.
:param output_path: path to compressed file.
"""
self.compressor.compress(input_path, output_path,
self.compressor_format)
def decompress(
self,
input_path: Union[Path, str],
output_path: Union[Path, str],
):
"""
Decompress input file and save to output file.
:param input_path: path to compressed file.
:param output_path: path to uncompressed file.
"""
self.extractor.extract(input_path, output_path, self.compressor_format)
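# A minimal usage sketch of `CompressManager` (the file names below are
# hypothetical):
#     manager = CompressManager(compressor_format='zstd')
#     manager.compress('cache-xxx.arrow', 'cache-xxx.arrow.zstd')
#     manager.decompress('cache-xxx.arrow.zstd', 'cache-xxx.arrow')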
class CacheCompressManager:
"""
    This class is used to compress or decompress Hugging Face dataset
    cache files using the specified compression format.
"""
def __init__(self, compressor_format: str = 'zstd'):
"""
Initialization method.
        :param compressor_format: compression format, default `zstd`.
"""
self.compressor_format = compressor_format
self.compressor_extension = '.' + compressor_format
self.compress_manager = CompressManager(
compressor_format=compressor_format)
self.pattern = re.compile(r'_\d{5}_of_')
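    # NOTE: `format_cache_file_name` is called by the methods below but was
    # missing from this listing. This is a hedged reconstruction based on
    # `self.pattern`: it collapses the shard index (e.g. `_00001_of_`) into
    # `_*_of_` so that sharded cache files are logged under a single name.
    def format_cache_file_name(
            self, cache_file_name: Optional[str]) -> Optional[str]:
        """
        Use `*` to replace the shard index in a cache file name.
        :param cache_file_name: a cache file name.
        :return: cache file name with the shard index replaced by `*`.
        """
        if not cache_file_name:
            return cache_file_name
        return self.pattern.sub('_*_of_', cache_file_name)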
def _get_raw_filename(self, filename: Union[Path, str]):
"""
Get a uncompressed file name from a compressed file.
:param filename: path to compressed file.
:return: path to uncompressed file.
"""
assert filename.endswith(self.compressor_format)
return str(filename)[:-len(self.compressor_extension)]
def _get_compressed_filename(self, filename: Union[Path, str]):
"""
        Get the compressed file name from an uncompressed file name.
:param filename: path to uncompressed file.
:return: path to compressed file.
"""
return str(filename) + self.compressor_extension
def _get_cache_directory(self, ds):
"""
Get dataset cache directory.
:param ds: input dataset.
:return: dataset cache directory.
"""
current_cache_files = [
os.path.abspath(cache_file['filename'])
for cache_file in ds.cache_files
]
if not current_cache_files:
return None
cache_directory = os.path.dirname(current_cache_files[0])
return cache_directory
def _get_cache_file_names(self,
cache_directory: str,
fingerprints: Union[str, List[str]] = None,
extension='.arrow'):
"""
        Get all cache files in the dataset cache directory with the given
        fingerprints that end with the specified extension.
        :param cache_directory: dataset cache directory.
        :param fingerprints: fingerprints of cache files. A string or a list
            is accepted. If `None`, find all cache files that start with
            `cache-` and end with the specified extension.
        :param extension: extension of cache files, default `.arrow`.
:return: list of file names
"""
if cache_directory is None:
return []
if fingerprints is None:
fingerprints = ['']
if isinstance(fingerprints, str):
fingerprints = [fingerprints]
files: List[str] = os.listdir(cache_directory)
f_names = []
for f_name in files:
for fingerprint in fingerprints:
if f_name.startswith(f'cache-{fingerprint}') \
and f_name.endswith(extension):
f_names.append(f_name)
return f_names
def compress(self,
prev_ds: Dataset,
this_ds: Dataset = None,
num_proc: int = 1):
"""
Compress cache files with fingerprint in dataset cache directory.
:param prev_ds: previous dataset whose cache files need to be
compressed here.
        :param this_ds: Current dataset that is computed from the previous
            dataset. Their cache files may overlap, so we must not compress
            cache files that will be used again by the current dataset. If
            it's None, all cache files of the previous dataset should be
            compressed.
:param num_proc: number of processes to compress cache files.
"""
        # exclude cache files that are still used by the current dataset
prev_cache_names = [item['filename'] for item in prev_ds.cache_files]
this_cache_names = [item['filename'] for item in this_ds.cache_files] \
if this_ds else []
caches_to_compress = list(
set(prev_cache_names) - set(this_cache_names))
files_to_remove = []
files_printed = set()
if num_proc > 1:
pool = Pool(num_proc)
for full_name in caches_to_compress:
            # ignore the cache files of the original dataset and only
            # consider the cache files of the following OPs
if not os.path.basename(full_name).startswith('cache-'):
continue
            # skip cache files that no longer exist
if not os.path.exists(full_name):
continue
compress_filename = self._get_compressed_filename(full_name)
formatted_cache_name = self.format_cache_file_name(
compress_filename)
if not os.path.exists(compress_filename):
if formatted_cache_name not in files_printed:
logger.info(
f'Compressing cache file to {formatted_cache_name}')
if num_proc > 1:
pool.apply_async(self.compress_manager.compress,
args=(
full_name,
compress_filename,
))
else:
self.compress_manager.compress(full_name,
compress_filename)
else:
if formatted_cache_name not in files_printed:
logger.debug(
f'Found compressed cache file {formatted_cache_name}')
files_printed.add(formatted_cache_name)
files_to_remove.append(full_name)
if num_proc > 1:
pool.close()
pool.join()
# clean up raw cache file
for file_path in files_to_remove:
logger.debug(f'Removing cache file {file_path}')
os.remove(file_path)
def decompress(self,
ds: Dataset,
fingerprints: Union[str, List[str]] = None,
num_proc: int = 1):
"""
        Decompress compressed cache files with the given fingerprints in the
        dataset cache directory.
        :param ds: input dataset.
        :param fingerprints: fingerprints of cache files. A string or a list
            is accepted. If `None`, find all cache files that start with
            `cache-` and end with the compression extension.
:param num_proc: number of processes to decompress cache files.
"""
cache_directory = self._get_cache_directory(ds)
if cache_directory is None:
return
# find compressed cache files with given fingerprints
f_names = self._get_cache_file_names(
cache_directory=cache_directory,
fingerprints=fingerprints,
extension=self.compressor_extension)
files_printed = set()
if num_proc > 1:
pool = Pool(num_proc)
for f_name in f_names:
full_name = os.path.abspath(os.path.join(cache_directory, f_name))
raw_filename = self._get_raw_filename(full_name)
formatted_cache_name = self.format_cache_file_name(raw_filename)
if not os.path.exists(raw_filename):
if formatted_cache_name not in files_printed:
logger.info(f'Decompressing cache file to '
f'{formatted_cache_name}')
files_printed.add(formatted_cache_name)
if num_proc > 1:
pool.apply_async(self.compress_manager.decompress,
args=(
full_name,
raw_filename,
))
else:
self.compress_manager.decompress(full_name, raw_filename)
            else:
                if formatted_cache_name not in files_printed:
                    logger.debug(f'Found uncompressed cache file '
                                 f'{formatted_cache_name}')
                    files_printed.add(formatted_cache_name)
if num_proc > 1:
pool.close()
pool.join()
def cleanup_cache_files(self, ds):
"""
        Clean up all compressed cache files in the dataset cache directory
        that start with `cache-` and end with the compression extension.
        :param ds: input dataset.
        :return: number of removed cache files.
"""
cache_directory = self._get_cache_directory(ds)
if cache_directory is None:
return
f_names = self._get_cache_file_names(
cache_directory=cache_directory,
extension=self.compressor_extension)
files_printed = set()
for f_name in f_names:
full_name = os.path.abspath(os.path.join(cache_directory, f_name))
formatted_cache_name = self.format_cache_file_name(full_name)
if formatted_cache_name not in files_printed:
logger.debug(f'Clean up cache file {formatted_cache_name}')
files_printed.add(formatted_cache_name)
os.remove(full_name)
return len(f_names)
class CompressionOff:
"""Define a range that turn off the cache compression temporarily."""
def __enter__(self):
"""
Record the original cache compression method and turn it off.
"""
from . import cache_utils
self.original_cache_compress = cache_utils.CACHE_COMPRESS
cache_utils.CACHE_COMPRESS = None
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Restore the original cache compression method.
"""
from . import cache_utils
cache_utils.CACHE_COMPRESS = self.original_cache_compress
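# A minimal usage sketch of `CompressionOff`: compression is disabled only
# within the `with` block and restored on exit:
#     with CompressionOff():
#         ...  # dataset operations whose cache should stay uncompressed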
def compress(prev_ds, this_ds=None, num_proc=1):
    """Compress unneeded cache files of the previous dataset if cache
    compression is enabled."""
if cache_utils.CACHE_COMPRESS:
CacheCompressManager(cache_utils.CACHE_COMPRESS).compress(
prev_ds, this_ds, num_proc)
def decompress(ds, fingerprints=None, num_proc=1):
    """Decompress the dataset's cache files with the given fingerprints if
    cache compression is enabled."""
if cache_utils.CACHE_COMPRESS:
CacheCompressManager(cache_utils.CACHE_COMPRESS).decompress(
ds, fingerprints, num_proc)
def cleanup_compressed_cache_files(ds):
    """Remove all compressed cache files in the dataset cache directory."""
CacheCompressManager().cleanup_cache_files(ds)
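# A minimal end-to-end sketch of the module-level helpers, assuming
# `cache_utils.CACHE_COMPRESS` is set to a supported format (e.g. 'zstd')
# and `prev_ds` / `this_ds` are `Dataset` objects produced by consecutive
# OPs (hypothetical names):
#     decompress(prev_ds, fingerprints=prev_ds._fingerprint, num_proc=4)
#     # ... run the next OP on `prev_ds` to produce `this_ds` ...
#     compress(prev_ds, this_ds, num_proc=4)
#     cleanup_compressed_cache_files(this_ds)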