import re
from itertools import chain
from typing import Union
from pydantic import NonNegativeInt, PositiveInt
from data_juicer.utils.model_utils import get_model, prepare_model
from ..base_op import OPERATORS, Mapper
OP_NAME = 'text_chunk_mapper'
[docs]
@OPERATORS.register_module(OP_NAME)
class TextChunkMapper(Mapper):
"""Split input text to chunks."""
_batched_op = True
[docs]
def __init__(self,
max_len: Union[PositiveInt, None] = None,
split_pattern: Union[str, None] = r'\n\n',
overlap_len: NonNegativeInt = 0,
tokenizer: Union[str, None] = None,
trust_remote_code: bool = False,
*args,
**kwargs):
"""
Initialization method.
:param max_len: Split text into multi texts with this max len if not
None.
:param split_pattern: Make sure split in this pattern if it is not None
and force cut if the length exceeds max_len.
:param overlap_len: Overlap length of the split texts if not split in
the split pattern.
:param tokenizer: The tokenizer name of Hugging Face tokenizers.
The text length will be calculate as the token num if it is offerd.
Otherwise, the text length equals to string length. Support
tiktoken tokenizer (such as gpt-4o), dashscope tokenizer (such as
qwen2.5-72b-instruct) and huggingface tokenizer.
:trust_remote_code: for loading huggingface model
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
if max_len is None and split_pattern is None:
raise ValueError('max_len and split_pattern cannot be both None')
if max_len is not None and overlap_len >= max_len:
raise ValueError('overlap_len must be less than max_len')
self.max_len = max_len
self.overlap_len = overlap_len
self.split_pattern = split_pattern
self.tokenizer_name = tokenizer
if tokenizer is not None:
self.model_key = prepare_model(
model_type='api',
model=tokenizer,
return_processor=True,
processor_config={'trust_remote_code': trust_remote_code})
[docs]
def recursively_chunk(self, text):
if self.tokenizer_name is not None:
_, tokenizer = get_model(self.model_key)
tokens = tokenizer.encode(text)
total_len = len(tokens)
sub_text = tokenizer.decode(tokens[:self.max_len])
else:
total_len = len(text)
sub_text = text[:self.max_len]
if total_len <= self.max_len:
return [text]
matches = list(re.finditer(self.split_pattern, sub_text))
if not matches:
cur_text = sub_text
if self.tokenizer_name is not None:
left_text = tokenizer.decode(tokens[self.max_len -
self.overlap_len:])
else:
left_text = text[self.max_len - self.overlap_len:]
else:
last_match = matches[-1]
cur_text = sub_text[:last_match.start()]
left_text = text[last_match.end():]
return [cur_text] + self.recursively_chunk(left_text)
[docs]
def get_text_chunks(self, text, rank=None):
if self.split_pattern is not None and self.max_len is None:
chunks = re.split(f'({self.split_pattern})', text)
chunks = [t for t in chunks if t.strip()]
elif self.split_pattern is None and self.max_len is not None:
tokens = text
total_len = len(text)
if self.tokenizer_name is not None:
_, tokenizer = get_model(self.model_key, rank=rank)
tokens = tokenizer.encode(text)
total_len = len(tokens)
if total_len <= self.max_len:
return [text]
chunks = []
for start in range(0, total_len, self.max_len - self.overlap_len):
cur = tokens[start:start + self.max_len]
if self.tokenizer_name is not None:
cur = tokenizer.decode(cur)
chunks.append(cur)
else:
chunks = self.recursively_chunk(text)
return chunks
[docs]
def process_batched(self, samples, rank=None):
sample_num = len(samples[self.text_key])
samples[self.text_key] = [
self.get_text_chunks(text, rank=rank)
for text in samples[self.text_key]
]
for key in samples:
if key != self.text_key:
samples[key] = [[samples[key][i]] *
len(samples[self.text_key][i])
for i in range(sample_num)]
for key in samples:
samples[key] = list(chain(*samples[key]))
return samples