# Source code for data_juicer.format.formatter
import os
from typing import List, Union
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
from loguru import logger
from data_juicer.utils.constant import Fields
from data_juicer.utils.file_utils import find_files_with_suffix, is_absolute_path
from data_juicer.utils.mm_utils import SpecialTokens
from data_juicer.utils.registry import Registry
FORMATTERS = Registry("Formatters")
# ``BaseFormatter`` is referenced by the formatters below; a minimal interface
# is assumed here.
class BaseFormatter:
    """Base class for dataset formatters."""

    def load_dataset(self, *args) -> Dataset:
        raise NotImplementedError


class LocalFormatter(BaseFormatter):
    """The class is used to load a dataset from local files or a local
    directory."""

    def __init__(
        self,
        dataset_path: str,
        type: str,
        suffixes: Union[str, List[str], None] = None,
        text_keys: List[str] = None,
        add_suffix=False,
        **kwargs,
    ):
"""
Initialization method.
:param dataset_path: path to a dataset file or a dataset
directory
:param type: a packaged dataset module type (json, csv, etc.)
:param suffixes: files with specified suffixes to be processed
:param text_keys: key names of field that stores sample
text.
:param add_suffix: whether to add the file suffix to dataset
meta info
:param kwargs: extra args
"""
        self.type = type
        self.kwargs = kwargs
        self.text_keys = text_keys
        self.data_files = find_files_with_suffix(dataset_path, suffixes)
        self.add_suffix = add_suffix

    def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset:
        """
        Load a dataset from a dataset file or a dataset directory, and
        unify its format.

        :param num_proc: number of processes when loading the dataset
        :param global_cfg: global cfg used in subsequent processes
        :return: formatted dataset
        """
        # ``num_proc`` may also arrive via the extra args; pop it so it is not
        # passed to ``load_dataset`` twice, and fall back to it when the
        # argument is falsy
        _num_proc = self.kwargs.pop("num_proc", 1)
        num_proc = num_proc or _num_proc
        datasets = load_dataset(
            self.type,
            # strip the dot from suffix keys, e.g. '.jsonl' -> 'jsonl'
            data_files={key.strip("."): self.data_files[key] for key in self.data_files},
            num_proc=num_proc,
            **self.kwargs,
        )
        if self.add_suffix:
            logger.info("Add suffix info into dataset...")
            datasets = add_suffixes(datasets, num_proc)
        else:
            from data_juicer.core.data import NestedDataset

            datasets = NestedDataset(concatenate_datasets([ds for _, ds in datasets.items()]))
        ds = unify_format(datasets, text_keys=self.text_keys, num_proc=num_proc, global_cfg=global_cfg)
        return ds
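
# A minimal usage sketch for ``LocalFormatter`` (the directory, suffix, and
# text key below are hypothetical; adjust them to your data):
#
#   formatter = LocalFormatter("./demo-data", type="json",
#                              suffixes=[".jsonl"], text_keys=["text"],
#                              add_suffix=True)
#   ds = formatter.load_dataset(num_proc=4)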

class RemoteFormatter(BaseFormatter):
    """The class is used to load a dataset from a repository on the
    HuggingFace Hub."""

    def __init__(self, dataset_path: str, text_keys: List[str] = None, **kwargs):
        """
        Initialization method.

        :param dataset_path: a dataset file or a dataset directory
        :param text_keys: key names of the fields that store sample
            text
        :param kwargs: extra args
        """
        self.path = dataset_path
        self.text_keys = text_keys
        self.kwargs = kwargs

    def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset:
        """
        Load a dataset from HuggingFace, and unify its format.

        :param num_proc: number of processes when loading the dataset
        :param global_cfg: the global cfg used in subsequent processes
        :return: formatted dataset
        """
        ds = load_dataset(self.path, split="train", num_proc=num_proc, **self.kwargs)
        ds = unify_format(ds, text_keys=self.text_keys, num_proc=num_proc, global_cfg=global_cfg)
        return ds
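
# A minimal usage sketch for ``RemoteFormatter`` (the repository name and
# text key are only examples):
#
#   formatter = RemoteFormatter("tatsu-lab/alpaca", text_keys=["instruction"])
#   ds = formatter.load_dataset(num_proc=4)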

def add_suffixes(datasets: DatasetDict, num_proc: int = 1) -> Dataset:
    """
    Add a suffix field to datasets.

    :param datasets: a DatasetDict object
    :param num_proc: number of processes to add suffixes
    :return: datasets with suffix features
    """
    logger.info("Add suffix column for dataset")
    from data_juicer.core.data import add_same_content_to_new_column

    for key, ds in datasets.items():
        if Fields.suffix not in ds.features:
            datasets[key] = ds.map(
                add_same_content_to_new_column,
                fn_kwargs={"new_column_name": Fields.suffix, "initial_value": "." + key},
                num_proc=num_proc,
                desc="Adding new column for suffix",
            )
    datasets = concatenate_datasets([ds for _, ds in datasets.items()])
    from data_juicer.core.data import NestedDataset

    return NestedDataset(datasets)
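
# A minimal sketch of ``add_suffixes`` (``ds_a`` and ``ds_b`` are hypothetical
# datasets keyed by their original file suffix, as built by ``LocalFormatter``):
#
#   datasets = DatasetDict({"jsonl": ds_a, "csv": ds_b})
#   merged = add_suffixes(datasets)  # adds Fields.suffix: '.jsonl' / '.csv'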

def unify_format(
    dataset: Dataset,
    text_keys: Union[List[str], str] = "text",
    num_proc: int = 1,
    global_cfg=None,
) -> Dataset:
    """
    Get a unified internal format; conduct the following modifications.

    1. check keys of dataset

    2. filter out those samples with empty or None text

    :param dataset: input dataset
    :param text_keys: original text key(s) of dataset
    :param num_proc: number of processes for mapping
    :param global_cfg: the global cfg used in subsequent processes,
        since cfg.text_key may be modified after unifying
    :return: unified_format_dataset
    """
    from data_juicer.core.data import NestedDataset

    if isinstance(dataset, DatasetDict):
        datasets = list(dataset.values())
        assert len(datasets) == 1, "Please make sure the passed DatasetDict contains only 1 dataset"
        dataset = datasets[0]
    assert isinstance(dataset, Dataset) or isinstance(
        dataset, NestedDataset
    ), "Currently we only support processing data with huggingface-Dataset format"
    if text_keys is None:
        text_keys = []
    if isinstance(text_keys, str):
        text_keys = [text_keys]
    logger.info("Unifying the input dataset formats...")
    dataset = NestedDataset(dataset)

    # 1. check text-related keys
    for key in text_keys:
        if key not in dataset.features:
            err_msg = (
                f"There is no key [{key}] in dataset. You might set "
                f"wrong text_key in the config file for your dataset. "
                f"Please check and retry!"
            )
            logger.error(err_msg)
            raise ValueError(err_msg)

    # 2. filter out those samples with empty or None text
    # TODO: optimize the filtering operation for better efficiency
    logger.info(f"There are {len(dataset)} sample(s) in the original dataset.")

    def non_empty_text(sample, target_keys):
        for target_key in target_keys:
            # TODO: case for CFT, in which len(sample[target_key]) == 0
            if sample[target_key] is None:
                # filter out samples that contain at least one None column,
                # since the ops cannot handle them for now
                return False
        return True

    dataset = dataset.filter(non_empty_text, num_proc=num_proc, fn_kwargs={"target_keys": text_keys})
    logger.info(f"{len(dataset)} samples left after filtering empty text.")

    # 3. convert relative paths to absolute paths
    if global_cfg:
        # check and get the dataset directory
        if global_cfg.get("dataset_path", None) and os.path.exists(global_cfg.dataset_path):
            if os.path.isdir(global_cfg.dataset_path):
                ds_dir = global_cfg.dataset_path
            else:
                ds_dir = os.path.dirname(global_cfg.dataset_path)
        else:
            ds_dir = ""
        image_key = global_cfg.get("image_key", SpecialTokens.image)
        audio_key = global_cfg.get("audio_key", SpecialTokens.audio)
        video_key = global_cfg.get("video_key", SpecialTokens.video)
        data_path_keys = []
        if image_key in dataset.features:
            data_path_keys.append(image_key)
        if audio_key in dataset.features:
            data_path_keys.append(audio_key)
        if video_key in dataset.features:
            data_path_keys.append(video_key)
        if len(data_path_keys) == 0:
            # no image/audio/video path lists in the dataset; no need to convert
            return dataset
        if ds_dir == "":
            return dataset

        logger.info(
            "Converting relative paths in the dataset to their "
            "absolute versions (based on the directory of the input "
            "dataset file)."
        )

        # function to convert relative paths to absolute paths
        def rel2abs(sample, path_keys, dataset_dir):
            for path_key in path_keys:
                if path_key not in sample:
                    continue
                paths = sample[path_key]
                if not paths:
                    continue
                new_paths = [path if is_absolute_path(path) else os.path.join(dataset_dir, path) for path in paths]
                sample[path_key] = new_paths
            return sample

        dataset = dataset.map(
            rel2abs, num_proc=num_proc, fn_kwargs={"path_keys": data_path_keys, "dataset_dir": ds_dir}
        )
    else:
        logger.warning(
            "No global config passed into the unify_format function. "
            "Relative paths in the dataset might not be converted "
            "to their absolute versions, and data of other modalities "
            "might not be found by Data-Juicer."
        )
    return dataset
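
# A minimal, self-contained sketch of ``unify_format``: a tiny in-memory
# dataset whose None text sample gets filtered out (illustrative only; real
# pipelines build datasets through the formatters above):
if __name__ == "__main__":
    demo = Dataset.from_dict({"text": ["hello", None, "world"]})
    unified = unify_format(demo, text_keys="text")
    print(len(unified))  # 2: the sample with None text was dropped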