import os
from functools import partial
from loguru import logger
from data_juicer.utils.constant import Fields, HashKeys
from data_juicer.utils.webdataset_utils import reconstruct_custom_webdataset_format
class RayExporter:
    """Export a ray dataset to files of a specific format.

    The target format is either passed explicitly as ``export_type`` or
    derived from the suffix of ``export_path``. Each supported format is
    mapped to one of the ``write_*`` static methods by :meth:`_router`.
    """

    # TODO: support config for export, some export methods require additional args
    _SUPPORTED_FORMATS = {
        "json",
        "jsonl",
        "parquet",
        "csv",
        "tfrecords",
        "webdataset",
        "lance",
        # 'images',
        # 'numpy',
    }

    def __init__(
        self,
        export_path,
        export_type=None,
        keep_stats_in_res_ds=True,
        keep_hashes_in_res_ds=False,
        **kwargs,
    ):
        """
        Initialization method.

        :param export_path: the path to export datasets.
        :param export_type: the format to export to. If None, the format
            is derived from the suffix of ``export_path``.
        :param keep_stats_in_res_ds: whether to keep stats in the result
            dataset.
        :param keep_hashes_in_res_ds: whether to keep hashes in the result
            dataset.
        :param kwargs: extra arguments, stored as ``export_extra_args`` and
            handed to the selected export method.
        :raises NotImplementedError: if the resolved format is not one of
            ``_SUPPORTED_FORMATS``.
        """
        self.export_path = export_path
        self.keep_stats_in_res_ds = keep_stats_in_res_ds
        self.keep_hashes_in_res_ds = keep_hashes_in_res_ds

        if export_type is None:
            self.export_format = self._get_export_format(export_path)
        else:
            self.export_format = export_type

        if self.export_format not in self._SUPPORTED_FORMATS:
            raise NotImplementedError(
                f'export data format "{self.export_format}" is not supported '
                f"for now. Only support {self._SUPPORTED_FORMATS}. Please check export_type or export_path."
            )

        # **kwargs is always a dict (possibly empty), never None.
        self.export_extra_args = kwargs

    def _get_export_format(self, export_path):
        """
        Derive the export format from the suffix of the export path.

        The suffix is lower-cased so that e.g. ``data.JSONL`` resolves to
        ``"jsonl"``. This method does not validate the result; the caller
        checks it against ``_SUPPORTED_FORMATS``.

        :param export_path: the path to export datasets.
        :return: the export data format; ``"jsonl"`` when the path has no
            suffix.
        """
        suffix = os.path.splitext(export_path)[-1].strip(".").lower()
        if not suffix:
            logger.warning(
                f'export_path "{export_path}" does not have a suffix. '
                f'We will use "jsonl" as the default export type.'
            )
            suffix = "jsonl"
        return suffix

    def _export_impl(self, dataset, export_path, columns=None):
        """
        Export a dataset to specific path.

        Stats/meta fields and hash fields are dropped first, depending on
        ``keep_stats_in_res_ds`` and ``keep_hashes_in_res_ds``.

        :param dataset: the dataset to export.
        :param export_path: the path to export the dataset.
        :param columns: the field names checked for removable stats/hash
            fields; when empty, ``dataset.columns()`` is used. Note that
            this does not restrict which columns are written.
        :return: whatever the selected export method returns.
        """
        # dataset.columns() may yield None when the schema is unknown;
        # treat that as "no known fields" instead of failing below.
        feature_fields = (columns if columns else dataset.columns()) or []

        removable = set()
        if not self.keep_stats_in_res_ds:
            removable |= {Fields.stats, Fields.meta}
        if not self.keep_hashes_in_res_ds:
            removable |= {
                HashKeys.hash,
                HashKeys.minhash,
                HashKeys.simhash,
                HashKeys.imagehash,
                HashKeys.videohash,
            }

        # Only drop fields that are actually present.
        removed_fields = list(removable.intersection(feature_fields))
        if removed_fields:
            dataset = dataset.drop_columns(removed_fields)

        export_method = RayExporter._router()[self.export_format]
        export_kwargs = {
            "export_extra_args": self.export_extra_args,
            "export_format": self.export_format,
        }
        return export_method(dataset, export_path, **export_kwargs)

    def export(self, dataset, columns=None):
        """
        Export method for a dataset.

        Writes the dataset to the ``export_path`` given at construction.

        :param dataset: the dataset to export.
        :param columns: the columns to export.
        :return:
        """
        self._export_impl(dataset, self.export_path, columns)

    @staticmethod
    def write_json(dataset, export_path, **kwargs):
        """
        Export method for json/jsonl target files.

        :param dataset: the dataset to export.
        :param export_path: the path to store the exported dataset.
        :param kwargs: extra arguments (currently unused).
        :return: the result of ``dataset.write_json``.
        """
        # force_ascii=False keeps non-ASCII characters unescaped.
        return dataset.write_json(export_path, force_ascii=False)

    @staticmethod
    def write_webdataset(dataset, export_path, **kwargs):
        """
        Export method for webdataset target files.

        If ``kwargs["export_extra_args"]["field_mapping"]`` is non-empty,
        every sample is first passed through
        ``reconstruct_custom_webdataset_format`` with that mapping.

        :param dataset: the dataset to export.
        :param export_path: the path to store the exported dataset.
        :param kwargs: extra arguments.
        :return: the result of ``dataset.write_webdataset``.
        """
        from data_juicer.utils.webdataset_utils import _custom_default_encoder

        # check if we need to reconstruct the customized WebDataset format;
        # `or {}` tolerates explicit None values for either entry.
        export_extra_args = kwargs.get("export_extra_args") or {}
        field_mapping = export_extra_args.get("field_mapping") or {}
        if field_mapping:
            reconstruct_func = partial(reconstruct_custom_webdataset_format, field_mapping=field_mapping)
            dataset = dataset.map(reconstruct_func)
        return dataset.write_webdataset(export_path, encoder=_custom_default_encoder)

    @staticmethod
    def write_others(dataset, export_path, **kwargs):
        """
        Export method for other target files.

        Dispatches to the dataset method named ``write_<export_format>``.

        :param dataset: the dataset to export.
        :param export_path: the path to store the exported dataset.
        :param kwargs: extra arguments; ``export_format`` selects the writer
            and defaults to ``"parquet"``.
        :return: the result of the selected dataset writer.
        """
        export_format = kwargs.get("export_format", "parquet")
        writer = getattr(dataset, f"write_{export_format}")
        return writer(export_path)

    # suffix to export method
    @staticmethod
    def _router():
        """
        A router from different suffixes to corresponding export methods.

        :return: A dict router.
        """
        return {
            "jsonl": RayExporter.write_json,
            "json": RayExporter.write_json,
            "webdataset": RayExporter.write_webdataset,
            "parquet": RayExporter.write_others,
            "csv": RayExporter.write_others,
            "tfrecords": RayExporter.write_others,
            "lance": RayExporter.write_others,
        }