Source code for data_juicer.core.data.dataset_builder

import os
import shlex
from argparse import Namespace
from typing import List, Tuple

import numpy as np
from datasets import concatenate_datasets
from loguru import logger

from data_juicer.core.data import DJDataset, NestedDataset
from data_juicer.core.data.config_validator import ConfigValidationError
from data_juicer.core.data.data_validator import DataValidatorRegistry
from data_juicer.core.data.load_strategy import DataLoadStrategyRegistry
from data_juicer.utils.file_utils import is_absolute_path
from data_juicer.utils.sample import random_sample


class DatasetBuilder(object):
    """
    DatasetBuilder is a class that builds a dataset from a configuration.
    """
    def __init__(self, cfg: Namespace, executor_type: str = 'default'):
        self.use_generated_dataset_config = False
        self.cfg = cfg
        self.executor_type = executor_type
        self.require_dataset_arg = False

        # priority: generated_dataset_config > dataset_path > dataset
        if hasattr(
                cfg,
                'generated_dataset_config') and cfg.generated_dataset_config:
            self.use_generated_dataset_config = True
            self.generated_dataset_config = cfg.generated_dataset_config
            return
        elif hasattr(cfg, 'dataset_path') and cfg.dataset_path:
            logger.info(f'found dataset_path setting: {cfg.dataset_path}')
            ds_configs = rewrite_cli_datapath(cfg.dataset_path)
        elif hasattr(cfg, 'dataset') and cfg.dataset:
            logger.info(f'found dataset setting: {cfg.dataset}')
            ds_configs = cfg.dataset
        else:
            logger.warning(
                'No dataset setting found in configurations. Will '
                'check the dataset argument before loading dataset.')
            self.require_dataset_arg = True
            return

        # validate dataset config for type constraints
        # TODO other constraints; ray dataset only supports local, etc.
        if not isinstance(ds_configs, dict):
            raise ConfigValidationError(
                'Dataset config should be a dictionary')
        if 'configs' not in ds_configs:
            raise ConfigValidationError(
                'Dataset config should have a "configs" key')
        if (not isinstance(ds_configs['configs'], list)
                or len(ds_configs['configs']) == 0):
            raise ConfigValidationError(
                'Dataset config "configs" should be a non-empty list')
        if ('max_sample_num' in ds_configs
                and (not isinstance(ds_configs['max_sample_num'], int)
                     or ds_configs['max_sample_num'] <= 0)):
            raise ConfigValidationError(
                'Dataset config "max_sample_num" should be a positive integer')
        for ds_config in ds_configs['configs']:
            if not isinstance(ds_config, dict):
                raise ConfigValidationError(
                    'Dataset configs should be dictionaries')
        types = [
            ds_config.get('type', None) for ds_config in ds_configs['configs']
        ]
        if len(set(types)) > 1:
            raise ConfigValidationError(
                'Mixture of different types (LOCAL/REMOTE/...) '
                'is not supported')
        if types[0] == 'remote' and len(ds_configs['configs']) > 1:
            raise ConfigValidationError(
                'Multiple remote datasets are not supported')

        # initialize the data load strategies
        self.load_strategies = []
        for ds_config in ds_configs['configs']:
            # initialize a data loading strategy for each dataset config
            data_type = ds_config.get('type', None)
            data_source = ds_config.get('source', None)
            stra_cls = DataLoadStrategyRegistry.get_strategy_class(
                self.executor_type, data_type, data_source)
            if stra_cls is None:
                raise ValueError(f'No data load strategy found for'
                                 f' {data_type} {data_source}')
            self.load_strategies.append(stra_cls(ds_config, cfg=self.cfg))

        # failed to initialize any load strategy
        if not self.load_strategies:
            logger.error(f'No data load strategies found for {ds_configs}')
            raise ConfigValidationError('No data load strategies found')

        # initialize the sample numbers
        self.max_sample_num = ds_configs.get('max_sample_num', None)

        # get weights and sample numbers
        if self.max_sample_num:
            self.weights = [stra.weight for stra in self.load_strategies]
            self.sample_numbers = get_sample_numbers(self.weights,
                                                     self.max_sample_num)
        else:
            self.weights = [1.0 for stra in self.load_strategies]
            self.sample_numbers = [None for stra in self.load_strategies]

        # initialize data validators
        self.validators = []
        if hasattr(cfg, 'validators'):
            for validator_config in cfg.validators:
                if 'type' not in validator_config:
                    raise ValueError(
                        'Validator config must have a "type" key')
                validator_type = validator_config['type']
                validator_cls = DataValidatorRegistry.get_validator(
                    validator_type)
                if validator_cls:
                    self.validators.append(validator_cls(validator_config))
                else:
                    raise ValueError(
                        f'No data validator found for {validator_type}')
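    # A minimal sketch of a dataset config accepted by this constructor
    # (illustrative only; the file names and weights are hypothetical). The
    # field names mirror the validation above: a 'configs' list of per-source
    # dicts plus an optional positive 'max_sample_num'.
    #
    # >>> cfg = Namespace(dataset={
    # ...     'configs': [
    # ...         {'type': 'local', 'path': 'ds1.jsonl', 'weight': 1.0},
    # ...         {'type': 'local', 'path': 'ds2.jsonl', 'weight': 3.0},
    # ...     ],
    # ...     'max_sample_num': 10000,
    # ... })
    # >>> builder = DatasetBuilder(cfg, executor_type='default')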
    def load_dataset(self, **kwargs) -> DJDataset:
        if self.require_dataset_arg:
            # should not get into this method without a dataset setting
            raise ValueError(
                'Unable to load dataset; should have one of '
                'generated_dataset_config, dataset_path, or dataset '
                'in configurations, or pass the `dataset` object through '
                'the `run` method')

        # if generated_dataset_config is present, prioritize it
        if self.use_generated_dataset_config:
            return DatasetBuilder.load_dataset_by_generated_config(
                self.generated_dataset_config)

        _datasets = []
        # load datasets with sample numbers
        for stra, weight, sample_num in zip(self.load_strategies,
                                            self.weights,
                                            self.sample_numbers):
            # load dataset with its load strategy
            dataset = stra.load_data(**kwargs)

            # do data validation
            for validator in self.validators:
                validator.validate(dataset)

            # do data sampling, if necessary
            if self.max_sample_num:
                dataset = random_sample(dataset, weight, sample_num)

            _datasets.append(dataset)

        # handle data mixture
        if self.executor_type == 'default':
            return NestedDataset(concatenate_datasets(_datasets))
        elif self.executor_type == 'ray':
            # TODO: support multiple datasets and mixing for ray
            assert len(_datasets) == 1, \
                'Ray setup only supports one dataset now'
            return _datasets[0]
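    # Usage sketch (illustrative; assumes `ds1.jsonl` and `ds2.jsonl` exist
    # as local files so the CLI-style path is rewritten into 'local' configs):
    #
    # >>> cfg = Namespace(dataset_path='0.5 ds1.jsonl 0.5 ds2.jsonl')
    # >>> builder = DatasetBuilder(cfg, executor_type='default')
    # >>> mixed = builder.load_dataset()  # NestedDataset for 'default' mode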
    @classmethod
    def load_dataset_by_generated_config(cls, generated_dataset_config):
        """
        load dataset by generated config
        """
        assert isinstance(generated_dataset_config,
                          dict) and 'type' in generated_dataset_config
        args = generated_dataset_config.copy()

        # TODO finish the auto local dataset part
        obj_name = args.pop('type')
        from data_juicer.format.formatter import FORMATTERS
        dataset = FORMATTERS.modules[obj_name](**args).load_dataset()
        return dataset
def rewrite_cli_datapath(dataset_path, max_sample_num=None) -> dict:
    """
    rewrite the dataset_path from CLI into a proper dataset config format
    that is compatible with the YAML config style; retrofitting CLI input
    of local files and huggingface path

    :param dataset_path: a dataset file or a dataset dir or a list of
        them, e.g. `<w1> ds1.jsonl <w2> ds2_dir <w3> ds3_file.json`
    :param max_sample_num: the maximum number of samples to load
    :return: a dataset config dict with a 'configs' list and an optional
        'max_sample_num'
    """
    paths, weights = parse_cli_datapath(dataset_path)
    ret = ({
        'configs': [],
        'max_sample_num': max_sample_num
    } if max_sample_num else {
        'configs': []
    })
    for p, w in zip(paths, weights):
        if os.path.isdir(p) or os.path.isfile(p):
            # local files
            ret['configs'].append({'type': 'local', 'path': p, 'weight': w})
        elif (not is_absolute_path(p) and not p.startswith('.')
              and p.count('/') <= 1):
            # remote huggingface
            ret['configs'].append({
                'type': 'huggingface',
                'path': p,
                'split': 'train'
            })
        else:
            # neither a local path nor a huggingface-style repo id
            raise ValueError(
                f'Unable to load the dataset from [{dataset_path}]. '
                f'Data-Juicer CLI mode only supports local files '
                f'w or w/o weights, or huggingface path')
    return ret
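# Example of the CLI rewrite (illustrative; assumes `ds1.jsonl` is an existing
# local file, while `org/dataset_name` is not, so it is treated as a remote
# huggingface path):
#
# >>> rewrite_cli_datapath('0.5 ds1.jsonl')
# {'configs': [{'type': 'local', 'path': 'ds1.jsonl', 'weight': 0.5}]}
# >>> rewrite_cli_datapath('org/dataset_name')
# {'configs': [{'type': 'huggingface', 'path': 'org/dataset_name',
#               'split': 'train'}]}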
def parse_cli_datapath(dataset_path) -> Tuple[List[str], List[float]]:
    """
    Split every dataset path and its weight.

    :param dataset_path: a dataset file or a dataset dir or a list of
        them, e.g. `<w1> ds1.jsonl <w2> ds2_dir <w3> ds3_file.json`
    :return: list of dataset paths and list of weights
    """
    # Handle empty input
    if not dataset_path or not dataset_path.strip():
        return [], []

    # Use shlex to properly handle quoted strings
    try:
        tokens = shlex.split(dataset_path)
    except ValueError as e:
        raise ValueError(f'Invalid dataset path format: {e}')

    prefixes = []
    weights = []

    for i in range(len(tokens)):
        try:
            # a token that parses as a non-negative float is a weight
            value = max(float(tokens[i]), 0.0)
            weights.append(value)
        except ValueError:
            value = tokens[i].strip()
            # if no weight was given for this path, use 1.0 as default
            if i == 0 or len(weights) == len(prefixes):
                weights.append(1.0)
            prefixes.append(value)

    return prefixes, weights
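# Example of the weight/path parsing (no filesystem access involved): a token
# that parses as a non-negative float becomes the weight of the path that
# follows it; paths without a preceding weight default to 1.0.
#
# >>> parse_cli_datapath('0.5 ds1.jsonl ds2_dir 2 ds3.json')
# (['ds1.jsonl', 'ds2_dir', 'ds3.json'], [0.5, 1.0, 2.0])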
def get_sample_numbers(weights, max_sample_num):
    sample_numbers = [0] * len(weights)

    # Normalize weights
    weights = np.array(weights, dtype=np.float64)
    sum_weights = np.sum(weights)
    assert sum_weights > 0.0
    weights /= sum_weights
    sample_num_per_dataset = [
        int(np.ceil(max_sample_num * weight)) for weight in weights
    ]

    # Adjust the ceiled counts so the running total stays within
    # max_sample_num
    acc_sample_numbers = 0
    for i in range(len(sample_num_per_dataset)):
        sample_numbers[i] = min(sample_num_per_dataset[i],
                                max_sample_num - acc_sample_numbers)
        acc_sample_numbers += sample_numbers[i]

    return sample_numbers
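# Worked example: weights [1.0, 3.0] normalize to [0.25, 0.75], the ceiled
# per-dataset counts are [25, 75], and the adjustment loop caps the running
# total at max_sample_num (visible in the second call, where three ceiled
# counts of 34 would otherwise exceed 100).
#
# >>> get_sample_numbers([1.0, 3.0], 100)
# [25, 75]
# >>> get_sample_numbers([1.0, 1.0, 1.0], 100)
# [34, 34, 32]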