Source code for data_juicer.format.mixture_formatter
from itertools import chain, repeat
from typing import List, Union

import numpy as np
from datasets import Dataset, concatenate_datasets
from loguru import logger

from .formatter import BaseFormatter, load_formatter


class MixtureFormatter(BaseFormatter):
    """The class mixes multiple datasets by randomly selecting samples from
    every dataset and merging them, and then exports the merged dataset as
    a new mixed dataset."""

    def __init__(self,
                 dataset_path: str,
                 suffixes: Union[str, List[str], None] = None,
                 text_keys=None,
                 add_suffix=False,
                 max_samples=None,
                 **kwargs):
"""
Initialization method.
:param dataset_path: a dataset file or a dataset dir or a list
of them, optional weights, default 1.0 e.g. `<w1> ds.jsonl
<w2> ds_dir <w3> ds_file.json`
:param suffixes: files with specified suffixes to be processed
:param text_keys: key names of field that stores sample text.
:param add_suffix: whether to add the file suffix to dataset
meta info
:param max_samples: max samples number of mixed dataset.
:param kwargs: extra args
"""
        data_prefixes, weights = self._get_weight(data_prefix=dataset_path)
        sample_numbers = [0] * len(weights)
        if max_samples is not None:
            # Normalize weights.
            weights = np.array(weights, dtype=np.float64)
            sum_weights = np.sum(weights)
            assert sum_weights > 0.0
            weights /= sum_weights
            sample_num_per_dataset = [
                int(np.ceil(max_samples * weight)) for weight in weights
            ]

            # Adjust the per-dataset numbers so their sum does not exceed
            # max_samples (np.ceil may over-allocate).
            acc_sample_numbers = 0
            for i in range(len(sample_num_per_dataset)):
                sample_numbers[i] = min(sample_num_per_dataset[i],
                                        max_samples - acc_sample_numbers)
                acc_sample_numbers += sample_numbers[i]
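
            # Worked example (illustrative numbers, not from the source):
            # with max_samples=100 and normalized weights [0.7, 0.3],
            # np.ceil gives [70, 30], which already sums to 100; with
            # weights [0.5, 0.5] and max_samples=99, np.ceil gives
            # [50, 50], which the loop above clips to [50, 49] so the
            # total stays at 99.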

        self.sample_numbers = sample_numbers
        self.weights = weights
        self.formatters = [
            load_formatter(dataset_path=data_prefix,
                           suffixes=suffixes,
                           text_keys=text_keys,
                           add_suffix=add_suffix,
                           **kwargs) for data_prefix in data_prefixes
        ]
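
    # Note: when max_samples is None, the weights are left un-normalized
    # and every entry of sample_numbers stays 0, so random_sample() below
    # falls back to weight-based sampling for each dataset.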

    def _get_weight(self, data_prefix):
        """
        Split every dataset path and its weight.

        :param data_prefix: a dataset file or a dataset dir or a list of
            them, e.g. `<w1> ds1.jsonl <w2> ds2_dir <w3> ds3_file.json`
        :return: list of dataset paths and list of weights
        """
        data_prefix = data_prefix.split()
        weights = []
        prefixes = []
        for i in range(len(data_prefix)):
            try:
                value = max(float(data_prefix[i]), 0.0)
                weights.append(value)
            except ValueError:
                value = data_prefix[i].strip()
                # if the weight is not set, use 1.0 as the default
                if i == 0 or len(weights) == len(prefixes):
                    weights.append(1.0)
                prefixes.append(value)
        return prefixes, weights
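
    # Parsing sketch (the paths are hypothetical):
    # _get_weight('0.5 ds1.jsonl ds2_dir') returns
    # (['ds1.jsonl', 'ds2_dir'], [0.5, 1.0]) -- ds1.jsonl takes the
    # explicit weight 0.5, while ds2_dir falls back to the default 1.0.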

    @classmethod
    def random_sample(cls, dataset, weight=1.0, sample_number=0, seed=None):
        """
        Randomly sample a subset from a dataset by weight or by number;
        if the sample number is bigger than 0, it is used instead of the
        weight.

        :param dataset: a HuggingFace dataset
        :param weight: sample ratio of dataset
        :param sample_number: sample number of dataset
        :param seed: random sample seed; 42 is used as the default if None
        :return: a subset of the dataset
        """
        if seed is None:
            seed = 42
        ds_samples = dataset.num_rows
        if sample_number <= 0:
            sample_number = int(np.ceil(ds_samples * weight))
        if sample_number == ds_samples:
            return dataset
        sample_index = range(sample_number)

        # If more samples are requested than the dataset holds, repeat the
        # full index range and top it up with the remaining indices.
        n_repeat = int(np.ceil(sample_number / ds_samples)) - 1
        if n_repeat > 0:
            remain_samples = sample_number - n_repeat * ds_samples
            sample_index = chain(*repeat(range(ds_samples), n_repeat),
                                 range(remain_samples))

        return dataset.shuffle(seed=seed).select(sample_index)
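
    # Oversampling sketch (illustrative numbers): for a 10-row dataset and
    # sample_number=25, n_repeat = ceil(25 / 10) - 1 = 2 and
    # remain_samples = 25 - 2 * 10 = 5, so the shuffled dataset is selected
    # twice in full plus its first 5 indices, yielding 25 rows.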

    def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset:
        """
        Load a mixed dataset.

        :param num_proc: number of processes when loading the dataset
        :param global_cfg: the global cfg used in subsequent processes
        :return: mixed dataset
        """
        dataset_list = []
        for weight, sample_num, formatter in zip(self.weights,
                                                 self.sample_numbers,
                                                 self.formatters):
            dataset = formatter.load_dataset(num_proc, global_cfg)
            sampled = self.random_sample(dataset, weight, sample_num)
            logger.info(f'sampled {len(sampled)} from {len(dataset)}')
            dataset_list.append(sampled)

        from data_juicer.core.data import NestedDataset
        mixed_dataset = NestedDataset(concatenate_datasets(dataset_list))
        logger.info(f'There are {len(mixed_dataset)} samples in the final '
                    f'mixed dataset')
        return mixed_dataset
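

# Usage sketch (a minimal, hypothetical example; the paths and the
# 0.5 / 1.0 weights below are illustrative, not from the source):
#
#     formatter = MixtureFormatter(
#         dataset_path='0.5 ds1.jsonl 1.0 ds2_dir',
#         max_samples=1000,
#     )
#     mixed = formatter.load_dataset(num_proc=4)
#
# `mixed` is a NestedDataset with at most 1000 samples drawn from the two
# sources in roughly a 1:2 ratio after weight normalization.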