Source code for data_juicer.core.data.ray_dataset

from __future__ import annotations

import os
import sys
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Union

import pyarrow
from jsonargparse import Namespace
from loguru import logger

from data_juicer.core.data import DJDataset
from data_juicer.core.data.schema import Schema
from data_juicer.ops import Deduplicator, Filter, Mapper
from data_juicer.ops.base_op import DEFAULT_BATCH_SIZE, TAGGING_OPS
from data_juicer.utils.constant import Fields
from data_juicer.utils.file_utils import is_remote_path
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.process_utils import calculate_np
from data_juicer.utils.resource_utils import cuda_device_count
from data_juicer.utils.webdataset_utils import _custom_default_decoder

ray = LazyLoader("ray")


def get_abs_path(path, dataset_dir):
    """Resolve ``path`` relative to ``dataset_dir`` when the result exists.

    Args:
        path: A path string taken from a dataset sample.
        dataset_dir: Directory used as the base for relative paths.

    Returns:
        ``path`` unchanged if ``is_remote_path`` reports it as remote or if
        the joined path does not exist on disk; otherwise the absolute form of
        ``dataset_dir/path``.
    """
    if not is_remote_path(path):
        candidate = os.path.abspath(os.path.join(dataset_dir, path))
        if os.path.exists(candidate):
            return candidate
    # Remote path, or nothing found under dataset_dir: hand back the input.
    return path
def convert_to_absolute_paths(samples, dataset_dir, path_keys):
    """Rewrite the path columns of one batch via ``get_abs_path``.

    Args:
        samples: A batch exposing ``to_pydict()`` (used with
            ``batch_format="pyarrow"`` by the caller in this file).
        dataset_dir: Base directory passed through to ``get_abs_path``.
        path_keys: Names of the columns whose values hold paths.

    Returns:
        A new ``pyarrow.Table`` built from the updated column dict.
    """
    columns = samples.to_pydict()
    for key in path_keys:
        values = columns[key]
        for idx, entry in enumerate(values):
            if isinstance(entry, str):
                values[idx] = get_abs_path(entry, dataset_dir)
            elif isinstance(entry, list):
                values[idx] = [get_abs_path(item, dataset_dir) for item in entry]
            # Any other value type is left as it is.
    return pyarrow.Table.from_pydict(columns)
# TODO: check paths for nested datasets
def set_dataset_to_absolute_path(dataset, dataset_path, cfg):
    """
    Set all media paths in the input data to absolute paths.

    The video / image / audio column names are read from ``cfg`` under the
    keys ``video_key`` / ``image_key`` / ``audio_key`` (falling back to
    ``"videos"`` / ``"images"`` / ``"audios"``). Every such column that exists
    in the dataset is rewritten batch by batch with
    ``convert_to_absolute_paths``, using the directory that contains
    ``dataset_path`` as the base directory.

    Args:
        dataset: Dataset exposing ``columns()`` and ``map_batches()``.
        dataset_path: Path of the dataset file; only its directory is used.
        cfg: Optional config object with a ``get(key, default)`` method.
            ``None`` means "use the default column names".

    Returns:
        The dataset, with a path-rewriting ``map_batches`` step attached when
        at least one path column is present; otherwise the input unchanged.
    """

    def _column_name(cfg_key, default):
        # cfg is optional (RayDataset.__init__ defaults it to None), so do not
        # call .get on it unconditionally.
        return default if cfg is None else cfg.get(cfg_key, default)

    # columns() can be None (the class in this file guards against that too);
    # treat it as "no known columns" instead of failing on the `in` test.
    columns = dataset.columns() or []
    candidate_keys = [
        _column_name("video_key", "videos"),
        _column_name("image_key", "images"),
        _column_name("audio_key", "audios"),
    ]
    path_keys = [key for key in candidate_keys if key in columns]

    if path_keys:
        dataset_dir = os.path.dirname(dataset_path)
        logger.info(f"dataset_dir: {dataset_dir}")
        dataset = dataset.map_batches(
            partial(convert_to_absolute_paths, dataset_dir=dataset_dir, path_keys=path_keys),
            batch_format="pyarrow",
            zero_copy_batch=True,
            batch_size=DEFAULT_BATCH_SIZE,
        )
    return dataset
def preprocess_dataset(dataset: ray.data.Dataset, dataset_path, cfg) -> ray.data.Dataset:
    """Apply dataset-level preprocessing before the dataset is wrapped.

    Args:
        dataset: The dataset to preprocess.
        dataset_path: Path the dataset was loaded from; when falsy no
            preprocessing is done.
        cfg: Config object forwarded to ``set_dataset_to_absolute_path``.

    Returns:
        The input dataset when ``dataset_path`` is falsy, otherwise the result
        of ``set_dataset_to_absolute_path``.
    """
    if not dataset_path:
        return dataset
    return set_dataset_to_absolute_path(dataset, dataset_path, cfg)
def get_num_gpus(op, op_proc):
    """Compute the GPU fraction to request for each process of ``op``.

    Args:
        op: Operator exposing ``use_cuda()``.
        op_proc: Number of processes the operator will run with.

    Returns:
        ``0`` when the operator does not use CUDA; otherwise the reciprocal of
        ``op_proc / cuda_device_count()``.
    """
    if op.use_cuda():
        procs_per_device = op_proc / cuda_device_count()
        return 1.0 / procs_per_device
    return 0
def filter_batch(batch, filter_func):
    """Keep only the rows of ``batch`` selected by ``filter_func``.

    Args:
        batch: A batch exposing ``to_pydict()`` and ``filter(mask)``.
        filter_func: Callable that receives the batch as a column dict and
            returns the per-row keep flags.

    Returns:
        The result of ``batch.filter`` applied with the computed mask.
    """
    keep_flags = filter_func(batch.to_pydict())
    return batch.filter(pyarrow.array(keep_flags))
class RayDataset(DJDataset):
    """``DJDataset`` implementation backed by a ``ray.data.Dataset``.

    The wrapped dataset is stored in ``self.data``; every operator run through
    :meth:`process` replaces ``self.data`` with the transformed dataset.
    """

    def __init__(
        self,
        dataset: ray.data.Dataset,
        dataset_path: Optional[str] = None,
        cfg: Optional[Namespace] = None,
    ) -> None:
        """Wrap ``dataset`` after running ``preprocess_dataset`` on it.

        Args:
            dataset: The Ray dataset to wrap.
            dataset_path: Path the dataset was loaded from, if any.
            cfg: Optional config forwarded to ``preprocess_dataset``.
        """
        self.data = preprocess_dataset(dataset, dataset_path, cfg)

    def schema(self) -> Schema:
        """Get dataset schema.

        Returns:
            Schema: Dataset schema containing column names and types

        Raises:
            ValueError: If the dataset is missing or reports no columns.
        """
        if self.data is None or self.data.columns() is None:
            raise ValueError("Dataset is empty or not initialized")
        return Schema.from_ray_schema(self.data.schema())

    def get(self, k: int) -> List[Dict[str, Any]]:
        """Get up to ``k`` rows from the dataset.

        Args:
            k: Maximum number of rows to return.

        Returns:
            A list with at most ``k`` rows (fewer if the dataset is smaller).

        Raises:
            ValueError: If ``k`` is negative.
        """
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")
        if k == 0:
            return []
        # Pass the limit explicitly: Dataset.take() defaults to 20 rows, which
        # would silently truncate any request for more than 20 rows.
        return list(self.data.take(k))

    def get_column(self, column: str, k: Optional[int] = None) -> List[Any]:
        """Get column values from Ray dataset.

        Args:
            column: Name of the column to retrieve
            k: Optional number of rows to return. If None, returns all rows

        Returns:
            List of values from the specified column

        Raises:
            KeyError: If column doesn't exist
            ValueError: If k is negative
        """
        columns = None if self.data is None else self.data.columns()
        if columns is None or column not in columns:
            raise KeyError(f"Column '{column}' not found in dataset")

        if k is None:
            # take() without an argument stops at 20 rows; take_all() is what
            # "all rows" requires.
            rows = self.data.take_all()
        else:
            if k < 0:
                raise ValueError(f"k must be non-negative, got {k}")
            if k == 0:
                return []
            rows = self.data.take(k)
        return [row[column] for row in rows]

    def process(self, operators, *, exporter=None, checkpointer=None, tracer=None) -> DJDataset:
        """Run one operator or a list of operators on the dataset, in order.

        Args:
            operators: A single operator, a list of operators, or ``None``.
            exporter: Accepted for interface compatibility; not used here.
            checkpointer: Accepted for interface compatibility; not used here.
            tracer: Accepted for interface compatibility; not used here.

        Returns:
            ``self``, with ``self.data`` updated by each operator.
        """
        if operators is None:
            return self
        if not isinstance(operators, list):
            operators = [operators]
        for op in operators:
            self._run_single_op(op)
        return self

    def _ensure_dict_column(self, column_name: str) -> None:
        """Append ``column_name`` filled with empty dicts unless it already exists."""
        if column_name in self.data.columns():
            return

        def add_empty_dict_column(table: pyarrow.Table):
            empty_dicts = [{} for _ in range(len(table))]
            return table.append_column(column_name, [empty_dicts])

        self.data = self.data.map_batches(
            add_empty_dict_column, batch_format="pyarrow", batch_size=DEFAULT_BATCH_SIZE
        )

    def _map_op_batches(self, op, cpu_fn, batch_size, op_proc, num_gpus, auto_parallel) -> None:
        """Apply ``op`` to ``self.data`` through ``map_batches``.

        For CUDA ops the op class itself is handed to ``map_batches`` together
        with its constructor kwargs; for CPU ops the bound method ``cpu_fn`` is
        mapped directly.
        """
        if op.use_cuda():
            self.data = self.data.map_batches(
                op.__class__,
                fn_args=None,
                fn_kwargs=None,
                fn_constructor_args=None,
                fn_constructor_kwargs=op._op_cfg[op._name],
                batch_size=batch_size,
                num_cpus=op.cpu_required,
                num_gpus=num_gpus,
                concurrency=op_proc,
                batch_format="pyarrow",
            )
        else:
            self.data = self.data.map_batches(
                cpu_fn,
                batch_size=batch_size,
                batch_format="pyarrow",
                num_cpus=op.cpu_required,
                # use ray default parallelism in cpu mode if num_proc is not specified
                concurrency=None if auto_parallel else op_proc,
            )

    def _run_single_op(self, op):
        """Apply a single Mapper, Filter or Deduplicator to ``self.data``.

        Any exception raised while applying the operator is logged, its
        traceback is printed, and the process exits with status 1.
        """
        # TODO: optimize auto proc
        auto_parallel = not op.num_proc
        requested_proc = op.num_proc if op.num_proc else sys.maxsize
        auto_op_proc = calculate_np(op._name, op.mem_required, op.cpu_required, op.use_cuda(), op.gpu_required)
        op_proc = min(requested_proc, auto_op_proc)

        # use ray default parallelism in cpu mode if op.num_proc is not specified
        if op.use_cuda() or not auto_parallel:
            logger.info(f"Op [{op._name}] running with number of procs:{op_proc}")

        num_gpus = op.gpu_required if op.gpu_required else get_num_gpus(op, op_proc)

        # Ops registered in TAGGING_OPS get a meta column created up front.
        if op._name in TAGGING_OPS.modules:
            self._ensure_dict_column(Fields.meta)

        try:
            batch_size = getattr(op, "batch_size", 1) if op.is_batched_op() else 1
            if isinstance(op, Mapper):
                self._map_op_batches(op, op.process, batch_size, op_proc, num_gpus, auto_parallel)
            elif isinstance(op, Filter):
                self._ensure_dict_column(Fields.stats)
                self._map_op_batches(op, op.compute_stats, batch_size, op_proc, num_gpus, auto_parallel)
                if op.stats_export_path is not None:
                    self.data.write_json(op.stats_export_path, force_ascii=False)
                if op.is_batched_op():
                    # The core computation have been done in compute_stats,
                    # and the filter process only performs simple filtering.
                    # cpu and parallelism are not set here
                    self.data = self.data.map_batches(
                        partial(filter_batch, filter_func=op.process),
                        batch_format="pyarrow",
                        zero_copy_batch=True,
                        batch_size=DEFAULT_BATCH_SIZE,
                    )
                else:
                    self.data = self.data.filter(op.process)
            elif isinstance(op, Deduplicator):
                self.data = op.run(self.data)
            else:
                logger.error("Ray executor only support Filter, Mapper and Deduplicator OPs for now")
                raise NotImplementedError
        except Exception:
            # Exception (not a bare except) so KeyboardInterrupt/SystemExit are
            # not swallowed; sys.exit because the `exit` builtin is only
            # installed by the `site` module.
            logger.error(f"An error occurred during Op [{op._name}].")
            import traceback

            traceback.print_exc()
            sys.exit(1)

    @classmethod
    def read(cls, data_format: str, paths: Union[str, List[str]]) -> ray.data.Dataset:
        """Read ``paths`` with the Ray reader that matches ``data_format``.

        Args:
            data_format: Format name, e.g. ``"json"``, ``"webdataset"``,
                ``"parquet"``.
            paths: A path or list of paths to read.

        Returns:
            The dataset produced by the selected reader (not wrapped in
            ``RayDataset``).

        Raises:
            ValueError: If ``data_format`` is not one of the supported formats.
        """
        if data_format in {"json", "jsonl"}:
            return cls.read_json(paths)
        if data_format == "webdataset":
            return cls.read_webdataset(paths)
        if data_format in {
            "parquet",
            "images",
            "parquet_bulk",
            "csv",
            "text",
            "avro",
            "numpy",
            "tfrecords",
            "binary_files",
            "lance",
        }:
            return getattr(ray.data, f"read_{data_format}")(paths)
        # Previously fell through and returned None, which only failed later
        # with an unrelated error.
        raise ValueError(f"Unsupported data format: {data_format!r}")

    @classmethod
    def read_json(cls, paths: Union[str, List[str]]) -> ray.data.Dataset:
        """Read JSON / JSON Lines files, streaming when ``pyarrow`` supports it.

        Uses ``read_json_stream`` when ``pyarrow.json`` exposes ``open_json``
        and falls back to ``ray.data.read_json`` on ``AttributeError``.
        """
        # Note: a temp solution for reading json stream
        # TODO: replace with ray.data.read_json_stream once it is available
        import pyarrow.json as js

        try:
            js.open_json  # probe only: raises AttributeError when unavailable
            return read_json_stream(paths)
        except AttributeError:
            return ray.data.read_json(paths)

    @classmethod
    def read_webdataset(cls, paths: Union[str, List[str]]) -> ray.data.Dataset:
        """Read WebDataset files using ``_custom_default_decoder`` with ``format="PIL"``."""
        return ray.data.read_webdataset(paths, decoder=partial(_custom_default_decoder, format="PIL"))

    def to_list(self) -> list:
        """Return the whole dataset as a list of row dicts (via pandas)."""
        return self.data.to_pandas().to_dict(orient="records")
class JSONStreamDatasource(ray.data.read_api.JSONDatasource):
    """
    Temporary datasource that reads JSON as a stream of batches.

    Note:
        Requires a `pyarrow` build whose `pyarrow.json` module provides
        `open_json`.
    """

    def _read_stream(self, f: "pyarrow.NativeFile", path: str):
        """Yield one ``pyarrow.Table`` per record batch read from ``f``.

        The schema of the first table is reused when building all later
        tables.

        Raises:
            ValueError: If pyarrow raises ``ArrowInvalid`` while reading;
                the message names ``path``.
        """
        from pyarrow.json import open_json

        try:
            stream_reader = open_json(
                f,
                read_options=self.read_options,
                **self.arrow_json_args,
            )
            first_schema = None
            while True:
                try:
                    record_batch = stream_reader.read_next_batch()
                    table = pyarrow.Table.from_batches([record_batch], schema=first_schema)
                    if first_schema is None:
                        first_schema = table.schema
                    yield table
                except StopIteration:
                    # The reader signals exhaustion with StopIteration.
                    return
        except pyarrow.lib.ArrowInvalid as err:
            raise ValueError(f"Failed to read JSON file: {path}.") from err
def read_json_stream(
    paths: Union[str, List[str]],
    *,
    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    parallelism: int = -1,
    ray_remote_args: Dict[str, Any] = None,
    arrow_open_stream_args: Optional[Dict[str, Any]] = None,
    meta_provider=None,
    partition_filter=None,
    partitioning=ray.data.read_api.Partitioning("hive"),
    include_paths: bool = False,
    ignore_missing_paths: bool = False,
    shuffle: Union[Literal["files"], None] = None,
    file_extensions: Optional[List[str]] = ["json", "jsonl"],
    concurrency: Optional[int] = None,
    override_num_blocks: Optional[int] = None,
    **arrow_json_args,
) -> ray.data.Dataset:
    """Build a dataset from JSON files using ``JSONStreamDatasource``.

    Args:
        paths: A path or list of paths to read.
        filesystem, arrow_open_stream_args, meta_provider, partition_filter,
        partitioning, include_paths, ignore_missing_paths, shuffle,
        file_extensions: Forwarded to ``JSONStreamDatasource``. When
            ``meta_provider`` is ``None`` a ``DefaultFileMetadataProvider``
            is created.
        parallelism, ray_remote_args, concurrency, override_num_blocks:
            Forwarded to ``ray.data.read_datasource``.
        **arrow_json_args: Passed to the datasource as ``arrow_json_args``.

    Returns:
        The dataset returned by ``ray.data.read_datasource``.
    """
    # NOTE(review): the default ``file_extensions`` list is shared between
    # calls; it is only forwarded here, never modified.
    provider = meta_provider
    if provider is None:
        provider = ray.data.read_api.DefaultFileMetadataProvider()

    datasource_options = {
        "arrow_json_args": arrow_json_args,
        "filesystem": filesystem,
        "open_stream_args": arrow_open_stream_args,
        "meta_provider": provider,
        "partition_filter": partition_filter,
        "partitioning": partitioning,
        "ignore_missing_paths": ignore_missing_paths,
        "shuffle": shuffle,
        "include_paths": include_paths,
        "file_extensions": file_extensions,
    }
    stream_source = JSONStreamDatasource(paths, **datasource_options)

    return ray.data.read_datasource(
        stream_source,
        parallelism=parallelism,
        ray_remote_args=ray_remote_args,
        concurrency=concurrency,
        override_num_blocks=override_num_blocks,
    )