Source code for data_juicer.utils.sample

from itertools import chain, repeat

import numpy as np


def _count_from_weight(num_rows, weight):
    """Return ``ceil(num_rows * weight)``, ignoring binary float round-off.

    In binary floating point ``30 * 0.1 == 3.0000000000000004``; a plain
    ceiling would turn that into 4 instead of the intended 3. Products that
    are within a tiny relative tolerance of an integer are snapped to it.
    """
    raw = num_rows * weight
    nearest = np.rint(raw)
    if np.isclose(raw, nearest, rtol=1e-9, atol=0.0):
        return int(nearest)
    return int(np.ceil(raw))


def random_sample(dataset, weight=1.0, sample_number=0, seed=None):
    """
    Randomly sample a subset from a dataset with weight or number,
    if sample number is bigger than 0, we will use sample number instead of
    weight.

    When the requested number exceeds the dataset size, the shuffled rows are
    repeated in order (full passes first, then a leading partial pass) until
    the requested number is reached.

    :param dataset: a HuggingFace dataset (only ``num_rows``, ``shuffle`` and
        ``select`` are used)
    :param weight: sample ratio of dataset; the row count is
        ``ceil(num_rows * weight)``
    :param sample_number: sample number of dataset; takes precedence over
        ``weight`` when > 0
    :param seed: random sample seed, if None, 42 as default
    :return: a subset of dataset; the dataset itself is returned unchanged
        (not shuffled) when the requested number equals its size or when it
        is empty
    """
    if seed is None:
        seed = 42

    ds_samples = dataset.num_rows
    if sample_number <= 0:
        sample_number = _count_from_weight(ds_samples, weight)

    # Nothing to resample. An empty dataset is also returned as-is: there are
    # no rows to draw from, and the oversampling arithmetic below would
    # otherwise divide by zero.
    if sample_number == ds_samples or ds_samples == 0:
        return dataset

    if sample_number <= ds_samples:
        # Keep a plain range so the dataset can use a contiguous selection.
        sample_index = range(sample_number)
    else:
        # Oversampling: n_full complete passes plus a partial leading pass.
        n_full, remain_samples = divmod(sample_number, ds_samples)
        sample_index = chain(*repeat(range(ds_samples), n_full),
                             range(remain_samples))
    return dataset.shuffle(seed=seed).select(sample_index)