# Some code here has been modified from:
# https://huggingface.co/spaces/huggingface/text-data-filtering
# --------------------------------------------------------
import sys
from ..base_op import OPERATORS, Mapper
from ..common import (SPECIAL_CHARACTERS, merge_on_whitespace_tab_newline,
split_on_newline_tab_whitespace, strip)
[docs]
@OPERATORS.register_module('remove_long_words_mapper')
class RemoveLongWordsMapper(Mapper):
"""Mapper to remove long words within a specific range."""
_batched_op = True
[docs]
def __init__(self,
min_len: int = 1,
max_len: int = sys.maxsize,
*args,
**kwargs):
"""
Initialization method.
:param min_len: The min mapper word length in this op, words
will be filtered if their length is below this parameter.
:param max_len: The max mapper word length in this op, words
will be filtered if their length exceeds this parameter.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.min_len = min_len
self.max_len = max_len
[docs]
def should_keep_long_word(self, word):
if self.min_len <= len(word) <= self.max_len:
return True
elif self.min_len <= len(strip(word,
SPECIAL_CHARACTERS)) <= self.max_len:
return True
else:
return False
[docs]
def process_batched(self, samples):
for idx, text in enumerate(samples[self.text_key]):
sentences = split_on_newline_tab_whitespace(text)
sentences = [[[
word for word in subsentence
if self.should_keep_long_word(word)
] for subsentence in sentence] for sentence in sentences]
samples[self.text_key][idx] = merge_on_whitespace_tab_newline(
sentences)
return samples