Source code for data_juicer.analysis.collector

from itertools import chain

from data_juicer.format import load_formatter
from data_juicer.utils.lazy_loader import LazyLoader

torch = LazyLoader('torch', 'torch')
transformers = LazyLoader('transformers', 'transformers')


class TextTokenDistCollector(object):
    """Tokenize and collect the distribution of tokens for a given dataset
    with a specified tokenizer.
    """

    def __init__(self, tokenizer):
        """
        Initialization method.

        :param tokenizer: tokenizer name on Hugging Face.
        """
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            tokenizer, trust_remote_code=True)
        self.vocab_size = len(self.tokenizer)

    def collect(self, data_path, text_key,
                num_proc=1) -> 'torch.distributions.Categorical':
        """
        Tokenize the input dataset and collect its token distribution.

        :param data_path: path to the input dataset.
        :param text_key: field key whose text is considered for token counts.
        :param num_proc: number of processes used to count tokens.
        :return: token distribution as a ``torch.distributions.Categorical``.
        """
        formatter = load_formatter(data_path)
        dataset = formatter.load_dataset(num_proc=num_proc)
        assert text_key in dataset.features, \
            f'[{text_key}] not found in dataset'

        def prepare_tokenizer(tokenizer, text_key):
            """
            Prepare a tokenization function for the dataset.

            :param tokenizer: a tokenizer to tokenize samples.
            :param text_key: field key whose text is considered for token
                counts.
            """

            def _tokenize_fn(example):
                return tokenizer(example[text_key], add_special_tokens=False)

            return _tokenize_fn

        # tokenize the target text field in parallel
        tokenize_proc = prepare_tokenizer(self.tokenizer, text_key)
        dataset = dataset.map(tokenize_proc,
                              num_proc=num_proc,
                              desc=f'tokenize {data_path.split("/")[-1]}')

        # count occurrences of every token id over the whole dataset
        token_count = torch.zeros(self.vocab_size, dtype=torch.int64)
        token_ids = torch.tensor(
            list(chain.from_iterable(dataset['input_ids'])))
        indices, counts = token_ids.unique(return_counts=True)
        token_count.scatter_(0, indices, counts.to(token_count.dtype))
        # Categorical normalizes the raw counts into a token distribution
        dist = torch.distributions.Categorical(token_count)
        return dist
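

A minimal usage sketch of the collector follows; the dataset path ``./demo.jsonl``, its ``text`` field, and the ``gpt2`` tokenizer name are illustrative assumptions, not part of the original module.

if __name__ == '__main__':
    # Hypothetical example: the dataset path, text field, and tokenizer
    # name below are placeholders for illustration only.
    collector = TextTokenDistCollector('gpt2')
    dist = collector.collect('./demo.jsonl', text_key='text', num_proc=4)
    # `dist` is a torch.distributions.Categorical over the tokenizer's
    # vocabulary; its normalized probabilities sum to 1.
    print(dist.probs.sum())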