from itertools import chain
from data_juicer.format import load_formatter
from data_juicer.utils.lazy_loader import LazyLoader
torch = LazyLoader('torch', 'torch')
transformers = LazyLoader('transformers', 'transformers')
class TextTokenDistCollector(object):
"""Tokenize and collect distribution of tokens for given
dataset with a specified tokenizer.
"""
def __init__(self, tokenizer):
"""
Initialization method.
        :param tokenizer: name of the tokenizer on HuggingFace.
"""
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer, trust_remote_code=True)
self.vocab_size = len(self.tokenizer)
def collect(self,
data_path,
text_key,
num_proc=1) -> 'torch.distributions.Categorical':
"""
        Tokenize the input dataset and collect its token distribution.
        :param data_path: path to input dataset.
        :param text_key: field key whose text is tokenized and counted.
:param num_proc: number of processes to count tokens.
:return: token distribution.
"""
formatter = load_formatter(data_path)
dataset = formatter.load_dataset(num_proc=num_proc)
        assert text_key in dataset.features, \
            f'{text_key} not found in dataset'
def prepare_tokenizer(
tokenizer,
text_key,
):
"""
Prepare a tokenizer function for dataset.
:param tokenizer: a tokenizer to tokenize sample.
:param text_key: field keys that will be
considered into token counts.
"""
            def _tokenize_fn(example):
example = tokenizer(example[text_key],
add_special_tokens=False)
return example
return _tokenize_fn
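
        # Tokenize every sample (in parallel when num_proc > 1); only the
        # resulting `input_ids` column is used for counting below.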
tokenize_proc = prepare_tokenizer(self.tokenizer, text_key)
dataset = dataset.map(tokenize_proc,
num_proc=num_proc,
desc=f'tokenize {data_path.split("/")[-1]}')
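
        # Accumulate per-token counts over the whole corpus, then wrap them
        # in a Categorical distribution (the counts are normalized to
        # probabilities internally).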
token_count = torch.zeros(self.vocab_size, dtype=torch.int64)
token_ids = torch.tensor(
list(chain.from_iterable(dataset['input_ids'])))
indices, counts = token_ids.unique(return_counts=True)
token_count.scatter_(0, indices, counts.to(token_count.dtype))
dist = torch.distributions.Categorical(token_count)
return dist
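

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. The tokenizer
    # name, dataset path, and text field below are illustrative assumptions,
    # not values required by data_juicer; running this needs the tokenizer to
    # be available (locally or via download) and a readable dataset file.
    collector = TextTokenDistCollector('gpt2')
    dist = collector.collect('demo-data.jsonl', text_key='text', num_proc=4)
    # `dist` is a torch.distributions.Categorical over the tokenizer's
    # vocabulary; its normalized token frequencies are exposed as `probs`.
    print(dist.probs.shape)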