import os
from typing import List, Union
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
from loguru import logger
from data_juicer.utils.constant import Fields
from data_juicer.utils.file_utils import (find_files_with_suffix,
is_absolute_path)
from data_juicer.utils.registry import Registry
# Registry of formatter classes. load_formatter() scores every entry of
# FORMATTERS.modules by how many found files match the class's SUFFIXES
# and instantiates the best match.
FORMATTERS = Registry('Formatters')
class BaseFormatter:
    """
    Common parent of every dataset formatter.

    Concrete formatters override :meth:`load_dataset`; this base version
    provides no loading logic of its own.
    """

    def load_dataset(self, *args) -> Dataset:
        """
        Load the dataset described by this formatter.

        :param args: formatter-specific positional arguments.
        :raises NotImplementedError: always, until a subclass overrides it.
        """
        raise NotImplementedError()
def add_suffixes(datasets: DatasetDict, num_proc: int = 1) -> Dataset:
    """
    Add a suffix field to every split of ``datasets`` and merge the splits.

    For each split whose features do not already contain ``Fields.suffix``,
    a ``Fields.suffix`` column is added whose value for every sample is
    ``'.' + <split key>``. All splits are then concatenated into a single
    dataset. The passed ``DatasetDict`` itself is left unmodified.

    :param datasets: a DatasetDict object; its keys are used as suffix
        names (presumably the file extensions without the dot -- confirm
        against callers)
    :param num_proc: number of processes to add suffixes
    :return: a NestedDataset concatenating all splits, each with the
        suffix feature.
    """
    logger.info('Add suffix column for dataset')
    # Imported locally (as in the original), presumably to avoid a circular
    # import with data_juicer.core.data -- TODO confirm.
    from data_juicer.core.data import (NestedDataset,
                                       add_same_content_to_new_column)

    # Collect the (possibly augmented) splits in a local list instead of
    # writing them back into the caller's DatasetDict.
    tagged_splits = []
    for key, ds in datasets.items():
        if Fields.suffix not in ds.features:
            ds = ds.map(add_same_content_to_new_column,
                        fn_kwargs={
                            'new_column_name': Fields.suffix,
                            'initial_value': '.' + key
                        },
                        num_proc=num_proc,
                        desc='Adding new column for suffix')
        tagged_splits.append(ds)

    merged = concatenate_datasets(tagged_splits)
    return NestedDataset(merged)
def unify_format(
    dataset: Dataset,
    text_keys: Union[List[str], str] = 'text',
    num_proc: int = 1,
    global_cfg=None,
) -> Dataset:
    """
    Get a unified internal format by applying the following modifications:

    1. check that every text key exists in the dataset;
    2. filter out samples in which any text key holds ``None`` (empty
       strings are kept);
    3. if ``global_cfg`` is given, convert relative paths stored under the
       configured image/audio/video keys into absolute paths joined onto
       ``global_cfg.dataset_dir``.

    :param dataset: input dataset. A DatasetDict is accepted only if it
        contains exactly one split, which is then used.
    :param text_keys: original text key(s) of dataset. ``None`` means no
        key is checked or filtered.
    :param num_proc: number of processes for mapping
    :param global_cfg: the global cfg used in consequent processes,
        since cfg.text_key may be modified after unifying. Its
        ``dataset_dir``, ``image_key``, ``audio_key`` and ``video_key``
        attributes are read here.
    :return: unified_format_dataset, wrapped as a NestedDataset
    :raises ValueError: if one of ``text_keys`` is not a dataset feature
    :raises AssertionError: if ``dataset`` is not a (single-split)
        huggingface Dataset
    """
    from data_juicer.core.data import NestedDataset

    if isinstance(dataset, DatasetDict):
        splits = list(dataset.values())
        assert len(splits) == 1, 'Please make sure the passed datasets ' \
                                 'contains only 1 dataset'
        dataset = splits[0]
    assert isinstance(dataset, (Dataset, NestedDataset)), \
        'Currently we only support processing data ' \
        'with huggingface-Dataset format'

    if text_keys is None:
        text_keys = []
    elif isinstance(text_keys, str):
        text_keys = [text_keys]

    logger.info('Unifying the input dataset formats...')

    dataset = NestedDataset(dataset)

    # 1. check text related keys
    for key in text_keys:
        if key not in dataset.features:
            err_msg = f'There is no key [{key}] in dataset. You might set ' \
                      f'wrong text_key in the config file for your dataset. ' \
                      f'Please check and retry!'
            logger.error(err_msg)
            raise ValueError(err_msg)

    # 2. filter out those samples with None text
    # TODO: optimize the filtering operation for better efficiency
    logger.info(f'There are {len(dataset)} sample(s) in the original dataset.')

    def non_empty_text(sample, target_keys):
        # TODO: case for CFT, in which the len(sample[target_key]) == 0
        # A sample is dropped if at least one text column is None, since
        # the ops can not handle None values now.
        return all(sample[target_key] is not None
                   for target_key in target_keys)

    dataset = dataset.filter(non_empty_text,
                             num_proc=num_proc,
                             fn_kwargs={'target_keys': text_keys})
    logger.info(f'{len(dataset)} samples left after filtering empty text.')

    # 3. convert relative paths to absolute paths
    if not global_cfg:
        logger.warning('No global config passed into unify_format function. '
                       'Relative paths in the dataset might not be converted '
                       'to their absolute versions. Data of other modalities '
                       'might not be able to find by Data-Juicer.')
        return dataset

    ds_dir = global_cfg.dataset_dir
    data_path_keys = [
        path_key for path_key in (global_cfg.image_key, global_cfg.audio_key,
                                  global_cfg.video_key)
        if path_key in dataset.features
    ]
    if not data_path_keys:
        # no image/audio/video path list in dataset, no need to convert
        return dataset
    if not ds_dir:
        # No base directory known ('' or None): nothing to join onto.
        return dataset

    logger.info('Converting relative paths in the dataset to their '
                'absolute version. (Based on the directory of input '
                'dataset file)')

    def rel2abs(sample, path_keys, dataset_dir):
        """Rewrite relative paths under ``path_keys`` as absolute paths."""

        def to_abs(path):
            return path if os.path.isabs(path) else os.path.join(
                dataset_dir, path)

        for path_key in path_keys:
            if path_key not in sample:
                continue
            paths = sample[path_key]
            if not paths:
                continue
            if isinstance(paths, str):
                # A single path stored as a plain string; iterating it
                # would split it into characters.
                sample[path_key] = to_abs(paths)
            else:
                sample[path_key] = [to_abs(path) for path in paths]
        return sample

    return dataset.map(rel2abs,
                       num_proc=num_proc,
                       fn_kwargs={
                           'path_keys': data_path_keys,
                           'dataset_dir': ds_dir
                       })
def load_formatter(dataset_path,
                   text_keys=None,
                   suffixes=None,
                   add_suffix=False,
                   **kwargs) -> BaseFormatter:
    """
    Load the appropriate formatter for different types of data formats.

    If ``dataset_path`` is an existing file or directory, the files found
    there are grouped by suffix. Each registered formatter is scored by how
    many of those files have a suffix listed in its ``SUFFIXES``, and the
    highest-scoring one (first registered wins ties) is instantiated with
    the suffixes it supports. Otherwise, a non-absolute path containing at
    most one ``/`` is handed to ``RemoteFormatter`` (HuggingFace hub).

    :param dataset_path: Path to dataset file or dataset directory
    :param text_keys: key names of field that stores sample text.
        Default: None
    :param suffixes: the suffix of files that will be read. Default:
        None
    :param add_suffix: forwarded as ``add_suffix`` to the selected local
        formatter (not passed to ``RemoteFormatter``). Default: False
    :param kwargs: extra keyword arguments forwarded to the formatter
        constructor
    :return: a dataset formatter.
    :raises IOError: if a local path contains no file matching ``suffixes``
    :raises ValueError: if the path is neither local nor a hub dataset name
    """
    if suffixes is None:
        suffixes = []

    # Number of found files per suffix; stays empty unless the path exists
    # locally (an existing path with no matching file raises below).
    ext_num = {}
    if os.path.isdir(dataset_path) or os.path.isfile(dataset_path):
        file_dict = find_files_with_suffix(dataset_path, suffixes)
        if not file_dict:
            raise IOError(
                'Unable to find files matching the suffix from {}'.format(
                    dataset_path))
        ext_num = {ext: len(file_dict[ext]) for ext in file_dict}

    # local dataset
    if ext_num:
        formatter_num = {
            name: sum(count for ext, count in ext_num.items()
                      if ext in formatter_cls.SUFFIXES)
            for name, formatter_cls in FORMATTERS.modules.items()
        }
        best_name = max(formatter_num, key=lambda x: formatter_num[x])
        best_cls = FORMATTERS.modules[best_name]
        target_suffixes = set(ext_num.keys()).intersection(
            set(best_cls.SUFFIXES))
        if not target_suffixes:
            # No registered formatter supports any found suffix; the choice
            # above is then just the first registered formatter.
            logger.warning(f'No registered formatter supports the suffixes '
                           f'{sorted(ext_num)} found in [{dataset_path}]. '
                           f'Falling back to formatter [{best_name}].')
        return best_cls(dataset_path,
                        text_keys=text_keys,
                        suffixes=target_suffixes,
                        add_suffix=add_suffix,
                        **kwargs)

    # try huggingface dataset hub
    if not is_absolute_path(dataset_path) and dataset_path.count('/') <= 1:
        return RemoteFormatter(dataset_path, text_keys=text_keys, **kwargs)

    # no data
    raise ValueError(f'Unable to load the dataset from [{dataset_path}]. '
                     f'It might be because Data-Juicer doesn\'t support '
                     f'the format of this dataset, or the path of this '
                     f'dataset is incorrect. Please check if it\'s a valid '
                     f'dataset path and retry.')