import os
from functools import partial
from loguru import logger
from data_juicer.utils.constant import Fields, HashKeys
from data_juicer.utils.file_utils import Sizes, byte_size_to_size_str
from data_juicer.utils.model_utils import filter_arguments
from data_juicer.utils.webdataset_utils import reconstruct_custom_webdataset_format
[docs]
class RayExporter:
    """Export a ray dataset to files of a specific format.

    The target format is either given explicitly (``export_type``) or derived
    from the suffix of ``export_path``. Each supported format is routed to one
    of the ``write_*`` static methods through :meth:`_router`.
    """

    # TODO: support config for export, some export methods require additional args
    _SUPPORTED_FORMATS = {
        "json",
        "jsonl",
        "parquet",
        "csv",
        "tfrecords",
        "webdataset",
        "lance",
        # 'images',
        # 'numpy',
    }

    def __init__(
        self,
        export_path,
        export_type=None,
        export_shard_size=0,
        keep_stats_in_res_ds=True,
        keep_hashes_in_res_ds=False,
        **kwargs,
    ):
        """
        Initialization method.

        :param export_path: the path to export datasets.
        :param export_type: the format type of the exported datasets. If it's
            None, the format is derived from the suffix of export_path.
        :param export_shard_size: the approximate size of each shard of exported
            dataset. In default, it's 0, which means export the dataset in the
            default setting of ray.
        :param keep_stats_in_res_ds: whether to keep stats in the result
            dataset.
        :param keep_hashes_in_res_ds: whether to keep hashes in the result
            dataset.
        :param kwargs: extra arguments stored and later offered to the
            underlying write method (after being filtered against its
            signature).
        :raises NotImplementedError: if the resolved export format is not in
            ``_SUPPORTED_FORMATS``.
        """
        self.export_path = export_path
        self.export_shard_size = export_shard_size
        self.keep_stats_in_res_ds = keep_stats_in_res_ds
        self.keep_hashes_in_res_ds = keep_hashes_in_res_ds
        self.export_format = self._get_export_format(export_path) if export_type is None else export_type
        if self.export_format not in self._SUPPORTED_FORMATS:
            raise NotImplementedError(
                f'export data format "{self.export_format}" is not supported '
                f"for now. Only support {self._SUPPORTED_FORMATS}. Please check export_type or export_path."
            )
        # **kwargs is always a dict (never None); keep a private copy of it.
        self.export_extra_args = dict(kwargs)

        # get the string format of shard size
        self.max_shard_size_str = byte_size_to_size_str(self.export_shard_size)

        # we recommend users to set a shard size between MiB and TiB.
        if 0 < self.export_shard_size < Sizes.MiB:
            logger.warning(
                f"The export_shard_size [{self.max_shard_size_str}]"
                f" is less than 1MiB. If the result dataset is too "
                f"large, there might be too many shard files to "
                f"generate."
            )
        if self.export_shard_size >= Sizes.TiB:
            logger.warning(
                f"The export_shard_size [{self.max_shard_size_str}]"
                f" is larger than 1TiB. It might generate large "
                f"single shard file and make loading and exporting "
                f"slower."
            )

    def _get_export_format(self, export_path):
        """
        Derive the export format from the suffix of the export path.

        No validation happens here; the caller checks the result against
        ``_SUPPORTED_FORMATS``.

        :param export_path: the path to export datasets.
        :return: the suffix of export_path without dots, or "jsonl" if the
            path has no suffix.
        """
        suffix = os.path.splitext(export_path)[-1].strip(".")
        if not suffix:
            logger.warning(
                f'export_path "{export_path}" does not have a suffix. '
                f'We will use "jsonl" as the default export type.'
            )
            suffix = "jsonl"
        return suffix

    def _compute_rows_per_file(self, dataset_nbytes, dataset_num_rows):
        """
        Translate the configured shard size into a per-file row count.

        :param dataset_nbytes: the size of the dataset in bytes.
        :param dataset_num_rows: the number of rows of the dataset.
        :return: the number of rows to put into each file, or None if no hint
            can be computed (sharding disabled or empty dataset).
        """
        if self.export_shard_size <= 0 or not dataset_num_rows or dataset_num_rows < 0:
            # an empty dataset would otherwise lead to a division by zero below
            return None
        num_shards = int((dataset_nbytes or 0) // self.export_shard_size) + 1
        # never plan more shards than there are rows
        num_shards = min(num_shards, dataset_num_rows)
        return dataset_num_rows // num_shards

    def _export_impl(self, dataset, export_path, columns=None):
        """
        Export a dataset to specific path.

        :param dataset: the dataset to export.
        :param export_path: the path to export the dataset.
        :param columns: the columns to export. Only used to decide which
            stats/hash fields need to be dropped; if empty, the columns of
            the dataset are used.
        :return: whatever the selected export method returns.
        """
        feature_fields = dataset.columns() if not columns else columns

        removed_fields = []
        if not self.keep_stats_in_res_ds:
            stats_fields = {Fields.stats, Fields.meta}
            removed_fields.extend(stats_fields.intersection(feature_fields))
        if not self.keep_hashes_in_res_ds:
            hash_fields = {
                HashKeys.hash,
                HashKeys.minhash,
                HashKeys.simhash,
                HashKeys.imagehash,
                HashKeys.videohash,
            }
            removed_fields.extend(hash_fields.intersection(feature_fields))
        if removed_fields:
            dataset = dataset.drop_columns(removed_fields)

        # dispatch through self so that subclasses can extend the router
        export_method = self._router()[self.export_format]

        # work on a copy: the per-export row hint must not be written back into
        # the arguments stored on the instance.
        export_extra_args = dict(self.export_extra_args)
        if self.export_shard_size > 0:
            # compute the min_rows_per_file for export methods.
            # If the object exposes an underlying dataset as `.data`, query that
            # one; otherwise query the object itself.
            data = getattr(dataset, "data", dataset)
            rows_per_file = self._compute_rows_per_file(data.size_bytes(), data.count())
            if rows_per_file is not None:
                export_extra_args["min_rows_per_file"] = rows_per_file

        return export_method(
            dataset,
            export_path,
            export_extra_args=export_extra_args,
            export_format=self.export_format,
        )

    def export(self, dataset, columns=None):
        """
        Export method for a dataset.

        :param dataset: the dataset to export.
        :param columns: the columns to export.
        :return:
        """
        self._export_impl(dataset, self.export_path, columns)

    @staticmethod
    def write_json(dataset, export_path, **kwargs):
        """
        Export method for json/jsonl target files.

        :param dataset: the dataset to export.
        :param export_path: the path to store the exported dataset.
        :param kwargs: extra arguments. The "export_extra_args" entry is
            filtered against ``dataset.write_json`` and forwarded to it.
        :return: the result of ``dataset.write_json``.
        """
        export_extra_args = kwargs.get("export_extra_args", {})
        filtered_kwargs = filter_arguments(dataset.write_json, export_extra_args)
        return dataset.write_json(export_path, force_ascii=False, **filtered_kwargs)

    @staticmethod
    def write_webdataset(dataset, export_path, **kwargs):
        """
        Export method for webdataset target files.

        :param dataset: the dataset to export.
        :param export_path: the path to store the exported dataset.
        :param kwargs: extra arguments. The "export_extra_args" entry may
            contain a "field_mapping"; if it's not empty, each sample is
            reconstructed with it before writing.
        :return: the result of ``dataset.write_webdataset``.
        """
        from data_juicer.utils.webdataset_utils import _custom_default_encoder

        # check if we need to reconstruct the customized WebDataset format
        export_extra_args = kwargs.get("export_extra_args", {})
        field_mapping = export_extra_args.get("field_mapping", {})
        if len(field_mapping) > 0:
            reconstruct_func = partial(reconstruct_custom_webdataset_format, field_mapping=field_mapping)
            dataset = dataset.map(reconstruct_func)
        filtered_kwargs = filter_arguments(dataset.write_webdataset, export_extra_args)
        return dataset.write_webdataset(export_path, encoder=_custom_default_encoder, **filtered_kwargs)

    @staticmethod
    def write_others(dataset, export_path, **kwargs):
        """
        Export method for other target files.

        :param dataset: the dataset to export.
        :param export_path: the path to store the exported dataset.
        :param kwargs: extra arguments. "export_format" (default "parquet")
            selects the ``write_<format>`` method of the dataset, and
            "export_extra_args" is filtered against that method.
        :return: the result of the selected write method.
        """
        export_format = kwargs.get("export_format", "parquet")
        # resolve the writer once and reuse it for filtering and for the call
        write_method = getattr(dataset, f"write_{export_format}")
        export_extra_args = kwargs.get("export_extra_args", {})
        filtered_kwargs = filter_arguments(write_method, export_extra_args)
        return write_method(export_path, **filtered_kwargs)

    # suffix to export method
    @staticmethod
    def _router():
        """
        A router from different suffixes to corresponding export methods.

        :return: A dict router.
        """
        return {
            "jsonl": RayExporter.write_json,
            "json": RayExporter.write_json,
            "webdataset": RayExporter.write_webdataset,
            "parquet": RayExporter.write_others,
            "csv": RayExporter.write_others,
            "tfrecords": RayExporter.write_others,
            "lance": RayExporter.write_others,
        }