# Source code for data_juicer.core.ray_exporter

import os
from functools import partial

from loguru import logger

from data_juicer.utils.constant import Fields, HashKeys
from data_juicer.utils.webdataset_utils import reconstruct_custom_webdataset_format


class RayExporter:
    """The Exporter class is used to export a ray dataset to files of specific format."""

    # TODO: support config for export, some export methods require additional args
    _SUPPORTED_FORMATS = {
        "json",
        "jsonl",
        "parquet",
        "csv",
        "tfrecords",
        "webdataset",
        "lance",
        # 'images',
        # 'numpy',
    }

    def __init__(
        self,
        export_path,
        export_type=None,
        keep_stats_in_res_ds=True,
        keep_hashes_in_res_ds=False,
        **kwargs,
    ):
        """
        Initialization method.

        :param export_path: the path to export datasets.
        :param export_type: explicit export format. When None, the format is
            inferred from the suffix of ``export_path``.
        :param keep_stats_in_res_ds: whether to keep stats in the result dataset.
        :param keep_hashes_in_res_ds: whether to keep hashes in the result dataset.
        :param kwargs: extra export arguments. They are stored as
            ``export_extra_args`` and forwarded to the export method
            (e.g. ``field_mapping`` is read by the webdataset writer).
        :raises NotImplementedError: if the resolved export format is not in
            ``_SUPPORTED_FORMATS``.
        """
        self.export_path = export_path
        self.keep_stats_in_res_ds = keep_stats_in_res_ds
        self.keep_hashes_in_res_ds = keep_hashes_in_res_ds
        self.export_format = self._get_export_format(export_path) if export_type is None else export_type
        if self.export_format not in self._SUPPORTED_FORMATS:
            raise NotImplementedError(
                f'export data format "{self.export_format}" is not supported '
                f"for now. Only support {self._SUPPORTED_FORMATS}. Please check export_type or export_path."
            )
        # **kwargs always binds a dict (possibly empty), never None.
        self.export_extra_args = kwargs

    def _get_export_format(self, export_path):
        """
        Infer the export format from the suffix of the export path.

        The suffix is lower-cased, so ``data.JSONL`` resolves to ``jsonl``.
        Whether the format is supported is checked by the caller (``__init__``).

        :param export_path: the path to export datasets.
        :return: the export data format; ``"jsonl"`` when the path has no suffix.
        """
        suffix = os.path.splitext(export_path)[-1].strip(".").lower()
        if not suffix:
            logger.warning(
                f'export_path "{export_path}" does not have a suffix. '
                f'We will use "jsonl" as the default export type.'
            )
            suffix = "jsonl"
        return suffix

    def _export_impl(self, dataset, export_path, columns=None):
        """
        Export a dataset to specific path.

        Stats/meta and hash columns are dropped first, unless the corresponding
        ``keep_*_in_res_ds`` flag is set.

        :param dataset: the dataset to export.
        :param export_path: the path to export the dataset.
        :param columns: the columns to consider. NOTE: this only scopes which
            columns are candidates for removal; it does not select columns.
            When empty/None, ``dataset.columns()`` is used.
        :return: whatever the format-specific export method returns.
        """
        feature_fields = columns if columns else dataset.columns()

        drop_candidates = set()
        if not self.keep_stats_in_res_ds:
            drop_candidates.update({Fields.stats, Fields.meta})
        if not self.keep_hashes_in_res_ds:
            drop_candidates.update(
                {
                    HashKeys.hash,
                    HashKeys.minhash,
                    HashKeys.simhash,
                    HashKeys.imagehash,
                    HashKeys.videohash,
                }
            )

        # Keep column order (deterministic) and never pass a name twice.
        removed_fields = [field for field in dict.fromkeys(feature_fields) if field in drop_candidates]
        if removed_fields:
            dataset = dataset.drop_columns(removed_fields)

        export_method = RayExporter._router()[self.export_format]
        export_kwargs = {
            "export_extra_args": self.export_extra_args,
            "export_format": self.export_format,
        }
        return export_method(dataset, export_path, **export_kwargs)

    def export(self, dataset, columns=None):
        """
        Export method for a dataset.

        :param dataset: the dataset to export.
        :param columns: the columns to export.
        :return: the result of the format-specific export method.
        """
        return self._export_impl(dataset, self.export_path, columns)

    @staticmethod
    def write_json(dataset, export_path, **kwargs):
        """
        Export method for json/jsonl target files.

        :param dataset: the dataset to export.
        :param export_path: the path to store the exported dataset.
        :param kwargs: extra arguments (currently unused).
        :return: the result of ``dataset.write_json``.
        """
        # force_ascii=False keeps non-ASCII text readable in the output.
        return dataset.write_json(export_path, force_ascii=False)

    @staticmethod
    def write_webdataset(dataset, export_path, **kwargs):
        """
        Export method for webdataset target files.

        If ``export_extra_args["field_mapping"]`` is non-empty, every sample is
        first passed through ``reconstruct_custom_webdataset_format``.

        :param dataset: the dataset to export.
        :param export_path: the path to store the exported dataset.
        :param kwargs: extra arguments; ``export_extra_args`` is read from here.
        :return: the result of ``dataset.write_webdataset``.
        """
        from data_juicer.utils.webdataset_utils import _custom_default_encoder

        # Tolerate explicit None values for either level of the mapping.
        export_extra_args = kwargs.get("export_extra_args") or {}
        field_mapping = export_extra_args.get("field_mapping") or {}
        if field_mapping:
            reconstruct_func = partial(reconstruct_custom_webdataset_format, field_mapping=field_mapping)
            dataset = dataset.map(reconstruct_func)
        return dataset.write_webdataset(export_path, encoder=_custom_default_encoder)

    @staticmethod
    def write_others(dataset, export_path, **kwargs):
        """
        Export method for other target files.

        Dispatches to ``dataset.write_<export_format>``.

        :param dataset: the dataset to export.
        :param export_path: the path to store the exported dataset.
        :param kwargs: extra arguments; ``export_format`` defaults to "parquet".
        :return: the result of the dataset's writer method.
        """
        export_format = kwargs.get("export_format", "parquet")
        return getattr(dataset, f"write_{export_format}")(export_path)

    # suffix to export method
    @staticmethod
    def _router():
        """
        A router from different suffixes to corresponding export methods.

        Its keys must match ``_SUPPORTED_FORMATS``.

        :return: A dict router.
        """
        return {
            "jsonl": RayExporter.write_json,
            "json": RayExporter.write_json,
            "webdataset": RayExporter.write_webdataset,
            "parquet": RayExporter.write_others,
            "csv": RayExporter.write_others,
            "tfrecords": RayExporter.write_others,
            "lance": RayExporter.write_others,
        }