import os
import shlex
from argparse import Namespace
from typing import List, Tuple
import numpy as np
from datasets import concatenate_datasets
from loguru import logger
from data_juicer.core.data import DJDataset, NestedDataset
from data_juicer.core.data.config_validator import ConfigValidationError
from data_juicer.core.data.data_validator import DataValidatorRegistry
from data_juicer.core.data.load_strategy import DataLoadStrategyRegistry
from data_juicer.utils.file_utils import is_absolute_path
from data_juicer.utils.sample import random_sample
class DatasetBuilder(object):
"""
DatasetBuilder is a class that builds a dataset from a configuration.
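
    A minimal usage sketch (illustrative only; the dataset path below is a
    hypothetical local file, and in practice ``cfg`` is usually produced by
    ``data_juicer.config.init_configs`` rather than built by hand)::

        from argparse import Namespace

        cfg = Namespace(dataset_path='./demo-dataset.jsonl')
        builder = DatasetBuilder(cfg, executor_type='default')
        dataset = builder.load_dataset()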
"""
def __init__(self, cfg: Namespace, executor_type: str = 'default'):
self.use_generated_dataset_config = False
self.cfg = cfg
self.executor_type = executor_type
self.require_dataset_arg = False
# priority: generated_dataset_config > dataset_path > dataset
if hasattr(
cfg,
'generated_dataset_config') and cfg.generated_dataset_config:
self.use_generated_dataset_config = True
self.generated_dataset_config = cfg.generated_dataset_config
return
elif hasattr(cfg, 'dataset_path') and cfg.dataset_path:
logger.info(f'found dataset_path setting: {cfg.dataset_path}')
ds_configs = rewrite_cli_datapath(cfg.dataset_path)
elif hasattr(cfg, 'dataset') and cfg.dataset:
logger.info(f'found dataset setting: {cfg.dataset}')
ds_configs = cfg.dataset
else:
logger.warning(
'No dataset setting found in configurations. Will '
'check the dataset argument before loading dataset.')
self.require_dataset_arg = True
return
# validate dataset config for type constraints
# TODO other constraints; ray dataset only supports local, etc.
if not isinstance(ds_configs, dict):
raise ConfigValidationError(
'Dataset config should be a dictionary')
if 'configs' not in ds_configs:
raise ConfigValidationError(
'Dataset config should have a "configs" key')
if (not isinstance(ds_configs['configs'], list)
or len(ds_configs['configs']) == 0):
raise ConfigValidationError(
'Dataset config "configs" should be a non-empty list')
if ('max_sample_num' in ds_configs
and (not isinstance(ds_configs['max_sample_num'], int)
or ds_configs['max_sample_num'] <= 0)):
raise ConfigValidationError(
'Dataset config "max_sample_num" should be a positive integer')
for ds_config in ds_configs['configs']:
if not isinstance(ds_config, dict):
raise ConfigValidationError(
'Dataset configs should be dictionaries')
types = [
ds_config.get('type', None) for ds_config in ds_configs['configs']
]
if len(set(types)) > 1:
raise ConfigValidationError(
                'Mixing different types (LOCAL/REMOTE/...) is not supported')
if types[0] == 'remote' and len(ds_configs['configs']) > 1:
raise ConfigValidationError(
'Multiple remote datasets are not supported')
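        # an example of a dataset config that passes the checks above
        # (illustrative only; paths and weights are hypothetical):
        #   dataset:
        #     max_sample_num: 10000
        #     configs:
        #       - type: local
        #         path: path/to/ds1.jsonl
        #         weight: 1.0
        #       - type: local
        #         path: path/to/ds2.jsonl
        #         weight: 2.0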
# initialize the data load strategies
self.load_strategies = []
for ds_config in ds_configs['configs']:
# initialize data loading strategy
data_type = ds_config.get('type', None)
data_source = ds_config.get('source', None)
            # look up the strategy class first so that a missing strategy
            # raises a clear error instead of a TypeError from calling None
            stra_cls = DataLoadStrategyRegistry.get_strategy_class(
                self.executor_type, data_type, data_source)
            if stra_cls is None:
                raise ValueError(f'No data load strategy found for'
                                 f' {data_type} {data_source}')
            self.load_strategies.append(stra_cls(ds_config, cfg=self.cfg))
# failed to initialize any load strategy
if not self.load_strategies:
logger.error(f'No data load strategies found for {ds_configs}')
raise ConfigValidationError('No data load strategies found')
        # initialize the sample numbers
self.max_sample_num = ds_configs.get('max_sample_num', None)
# get weights and sample numbers
if self.max_sample_num:
self.weights = [stra.weight for stra in self.load_strategies]
self.sample_numbers = get_sample_numbers(self.weights,
self.max_sample_num)
else:
self.weights = [1.0 for stra in self.load_strategies]
self.sample_numbers = [None for stra in self.load_strategies]
# initialize data validators
self.validators = []
if hasattr(cfg, 'validators'):
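            # cfg.validators is expected to be a list of dicts whose 'type'
            # names a validator registered in DataValidatorRegistry, e.g.
            # (illustrative only; the validator name and extra field are
            # hypothetical):
            #   validators:
            #     - type: <registered_validator_name>
            #       min_turns: 2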
for validator_config in cfg.validators:
if 'type' not in validator_config:
raise ValueError('Validator config must have a "type" key')
validator_type = validator_config['type']
validator_cls = DataValidatorRegistry.get_validator(
validator_type)
if validator_cls:
self.validators.append(validator_cls(validator_config))
else:
raise ValueError(
f'No data validator found for {validator_type}')
def load_dataset(self, **kwargs) -> DJDataset:
if self.require_dataset_arg:
# should not get into this method
raise ValueError(
'Unable to load dataset; should have one of '
'generated_dataset_config, dataset_path, or dataset '
'in configurations, or pass the `dataset` object through `run`'
' method')
# if generated_dataset_config present, prioritize
if self.use_generated_dataset_config:
return DatasetBuilder.load_dataset_by_generated_config(
self.generated_dataset_config)
_datasets = []
# load datasets with sample numbers
for stra, weight, sample_num in zip(self.load_strategies, self.weights,
self.sample_numbers):
# load dataset with its load strategy
dataset = stra.load_data(**kwargs)
# do data validation
for validator in self.validators:
validator.validate(dataset)
# do data sampling, if necessary
if self.max_sample_num:
dataset = random_sample(dataset, weight, sample_num)
_datasets.append(dataset)
# handle data mixture
if self.executor_type == 'default':
return NestedDataset(concatenate_datasets(_datasets))
elif self.executor_type == 'ray':
# TODO: support multiple datasets and mixing for ray
assert len(
_datasets) == 1, 'Ray setup only supports one dataset now'
return _datasets[0]
@classmethod
def load_dataset_by_generated_config(cls, generated_dataset_config):
"""
        Load a dataset according to the generated dataset config.
"""
assert isinstance(generated_dataset_config,
dict) and 'type' in generated_dataset_config
args = generated_dataset_config.copy()
# TODO finish the auto local dataset part
obj_name = args.pop('type')
from data_juicer.format.formatter import FORMATTERS
dataset = FORMATTERS.modules[obj_name](**args).load_dataset()
return dataset
def rewrite_cli_datapath(dataset_path, max_sample_num=None) -> dict:
    """
    Rewrite the dataset_path from the CLI into the dataset config format
    used in YAML configs, retrofitting CLI input of local files and
    huggingface paths.
:param dataset_path: a dataset file or a dataset dir or a list of
them, e.g. `<w1> ds1.jsonl <w2> ds2_dir <w3> ds3_file.json`
:param max_sample_num: the maximum number of samples to load
    :return: dataset config dict with a 'configs' list and an optional
        'max_sample_num'
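
    A sketch of the rewrite (illustrative; assumes `ds1.jsonl` exists as a
    local file)::

        rewrite_cli_datapath('1.0 ds1.jsonl')
        # -> {'configs': [{'type': 'local', 'path': 'ds1.jsonl',
        #                  'weight': 1.0}]}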
"""
paths, weights = parse_cli_datapath(dataset_path)
ret = ({
'configs': [],
'max_sample_num': max_sample_num
} if max_sample_num else {
'configs': []
})
for p, w in zip(paths, weights):
if os.path.isdir(p) or os.path.isfile(p):
# local files
ret['configs'].append({'type': 'local', 'path': p, 'weight': w})
elif (not is_absolute_path(p) and not p.startswith('.')
and p.count('/') <= 1):
# remote huggingface
ret['configs'].append({
'type': 'huggingface',
'path': p,
'split': 'train'
})
        else:
            # neither an existing local path nor a huggingface-style repo id
            raise ValueError(
                f'Unable to load the dataset from [{dataset_path}]. '
                f'Data-Juicer CLI mode only supports local files '
                f'with or without weights, or huggingface paths.')
return ret
def parse_cli_datapath(dataset_path) -> Tuple[List[str], List[float]]:
"""
    Split a CLI dataset path string into dataset paths and their weights.
:param dataset_path: a dataset file or a dataset dir or a list of
them, e.g. `<w1> ds1.jsonl <w2> ds2_dir <w3> ds3_file.json`
    :return: list of dataset paths and list of weights
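
    Example (illustrative)::

        parse_cli_datapath('0.5 ds1.jsonl ds2_dir 2 ds3.json')
        # -> (['ds1.jsonl', 'ds2_dir', 'ds3.json'], [0.5, 1.0, 2.0])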
"""
# Handle empty input
if not dataset_path or not dataset_path.strip():
return [], []
# Use shlex to properly handle quoted strings
try:
tokens = shlex.split(dataset_path)
except ValueError as e:
raise ValueError(f'Invalid dataset path format: {e}')
prefixes = []
weights = []
for i in range(len(tokens)):
try:
value = max(float(tokens[i]), 0.0)
weights.append(value)
        except ValueError:
value = tokens[i].strip()
# if not set weight, use 1.0 as default
if i == 0 or len(weights) == len(prefixes):
weights.append(1.0)
prefixes.append(value)
return prefixes, weights
def get_sample_numbers(weights, max_sample_num):
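    """
    Distribute max_sample_num across datasets in proportion to weights.

    Per-dataset counts are rounded up first, then truncated so the running
    total never exceeds max_sample_num, e.g. (illustrative)::

        get_sample_numbers([1, 1, 1], 10)  # -> [4, 4, 2]
    """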
sample_numbers = [0] * len(weights)
# Normalize weights
weights = np.array(weights, dtype=np.float64)
sum_weights = np.sum(weights)
assert sum_weights > 0.0
weights /= sum_weights
sample_num_per_dataset = [
int(np.ceil(max_sample_num * weight)) for weight in weights
]
# Adjust
acc_sample_numbers = 0
for i in range(len(sample_num_per_dataset)):
sample_numbers[i] = min(sample_num_per_dataset[i],
max_sample_num - acc_sample_numbers)
acc_sample_numbers += sample_numbers[i]
return sample_numbers