# Some code here has been modified from:
# https://github.com/bigscience-workshop/data-preparation
# --------------------------------------------------------
from collections import defaultdict, deque
from typing import Dict, Optional, Set
import numpy as np
import regex
from loguru import logger
from pydantic import PositiveInt
from data_juicer.utils.constant import HashKeys
from data_juicer.utils.lazy_loader import LazyLoader
from ..base_op import OPERATORS, Deduplicator
from ..common.helper_func import split_on_whitespace
simhash = LazyLoader('simhash', 'simhash')
OP_NAME = 'document_simhash_deduplicator'
@OPERATORS.register_module(OP_NAME)
class DocumentSimhashDeduplicator(Deduplicator):
"""Deduplicator to deduplicate samples at document-level using SimHash."""
def __init__(self,
tokenization: str = 'space',
window_size: PositiveInt = 6,
lowercase: bool = True,
ignore_pattern: Optional[str] = None,
num_blocks: PositiveInt = 6,
hamming_distance: PositiveInt = 4,
*args,
**kwargs):
"""
        Initialization method.

        :param tokenization: tokenization method for sample texts. It
            should be one of [space, punctuation, character]. For
            English-like languages, we recommend 'space', and for
            Chinese-like languages, we recommend 'character'.
        :param window_size: window size of shingling
        :param lowercase: whether to convert text to lower case first
        :param ignore_pattern: regex pattern of sub-strings to ignore
            (remove from the text) when computing simhash
        :param num_blocks: number of blocks in simhash computing
        :param hamming_distance: the max hamming distance threshold in
            near-duplicate detection. When the hamming distance of two
            sample texts is <= this threshold, they are regarded as
            similar samples and this op will only keep one of them after
            deduplication. This threshold must always be less than
            num_blocks.
"""
# about simhash computation
super().__init__(*args, **kwargs)
self.tokenization = tokenization
self.window_size = window_size
self.lowercase = lowercase
self.ignore_pattern = ignore_pattern
if self.ignore_pattern:
self.ignore_pattern = regex.compile(self.ignore_pattern)
# check parameters
if self.ignore_pattern and self.tokenization == 'punctuation':
            logger.warning('Be careful that punctuation-based tokenization '
                           'won\'t work if the ignore pattern includes '
                           'punctuation.')
self.punctuation_pattern = regex.compile(r'\p{P}')
# about deduplication
self.num_blocks = num_blocks
self.hamming_distance = hamming_distance
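
    # A minimal usage sketch (assumed workflow, not part of the original
    # source). It assumes a HuggingFace-style dataset whose samples carry
    # the configured text key:
    #
    #   dedup = DocumentSimhashDeduplicator(
    #       tokenization='character',  # recommended for Chinese-like texts
    #       window_size=6,
    #       num_blocks=6,
    #       hamming_distance=4,        # must stay below num_blocks
    #   )
    #   dataset = dataset.map(dedup.compute_hash)
    #   dataset, dup_pairs = dedup.process(dataset, show_num=3)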
def compute_hash(self, sample):
"""
        Compute the simhash value for the sample.

        :param sample: input sample
        :return: sample with the simhash value.
"""
# check if it's computed already
if HashKeys.simhash in sample:
return sample
text = sample[self.text_key]
if self.lowercase:
text = text.lower()
if self.ignore_pattern:
text = self.ignore_pattern.sub('', text)
# get tokens for different tokenization method
tokens = []
if self.tokenization == 'character':
tokens = [
str.encode(text[i:i + self.window_size])
for i in range(len(text) - self.window_size)
]
elif self.tokenization == 'punctuation':
tokens = self.punctuation_pattern.split(text)
tokens = [
str.encode(' '.join(tokens[i:i + self.window_size]))
for i in range(len(tokens) - self.window_size)
]
elif self.tokenization == 'space':
tokens = split_on_whitespace(text)
tokens = [
str.encode(' '.join(tokens[i:i + self.window_size]))
for i in range(len(tokens) - self.window_size)
]
else:
raise NotImplementedError(
f'Unimplemented tokenization method [{self.tokenization}]')
# compute simhash
sample[HashKeys.simhash] = str(
np.uint64(simhash.compute(map(simhash.unsigned_hash, tokens))))
return sample
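
    # Illustrative sketch of the shingling above (hypothetical input, not
    # from the original source): with tokenization='space' and
    # window_size=3, the text 'a b c d e' is split into whitespace tokens
    # ['a', 'b', 'c', 'd', 'e'], which yields the byte shingles
    # b'a b c' and b'b c d'; each shingle is hashed with
    # simhash.unsigned_hash and the hashes are folded into a single
    # 64-bit fingerprint by simhash.compute.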
def process(self, dataset, show_num=0):
"""
        Deduplicate samples at document level: dataset --> dataset.

        :param dataset: input dataset
        :param show_num: number of traced samples used when the tracer is
            enabled.
:return: deduplicated dataset and the sampled duplicate pairs.
"""
# no need to deduplicate because too few samples
if len(dataset) <= 1:
return dataset, {}
# find matches
logger.info(f'Start querying {len(dataset)} samples.')
matches = simhash.find_all(
np.uint64(dataset[HashKeys.simhash]),
self.num_blocks,
self.hamming_distance,
)
logger.info(f'Querying done, found {len(matches)} matches.')
# compute hash diff distribution
graph = defaultdict(dict)
for x, y in matches:
x = str(x)
y = str(y)
graph[x][y] = graph[y][x] = True
hash2ids: Dict[str, Set[str]] = defaultdict(set)
hashes: Set[str] = set(dataset[HashKeys.simhash])
hash2cluster: Dict[str, int] = {}
visited: Set[str] = set()
cluster_id: int = 0
for sid, hash_val in enumerate(dataset[HashKeys.simhash]):
hash2ids[hash_val].add(str(sid))
# clustering
dup_pairs = {} # store duplicate pairs when show_num > 0
while hashes:
hash_val = hashes.pop()
if hash_val in visited:
continue
# if this hash value is not in the matches list, it's regarded as a
# single cluster
if hash_val not in graph:
continue
# Otherwise, BFS to find the cluster
q = deque([hash_val])
visited.add(hash_val)
hash2cluster[hash_val] = cluster_id
if show_num > 0 and len(dup_pairs) < show_num:
dup_pairs[cluster_id] = []
while q:
curr = q.popleft()
for neighbor in graph[curr]:
if neighbor in visited:
continue
visited.add(neighbor)
q.append(neighbor)
hash2cluster[neighbor] = cluster_id
cluster_id += 1
logger.info(f'Found {cluster_id} clusters and {len(graph)} hashes.')
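        # Illustrative sketch of the clustering above (hypothetical hash
        # values, not from the original source): if simhash.find_all
        # returns matches [(h1, h2), (h2, h3)], the adjacency dict `graph`
        # links h1-h2-h3 and the BFS assigns all three hashes the same
        # cluster_id. A hash with no match never enters `graph` or
        # `hash2cluster`, so it is treated as its own single-sample
        # cluster below.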
# filter duplicated samples
# NOTICE: For now, we only keep the first sample in a cluster. Maybe
# there are some better strategies later.
def _filter_simhash_dup_helper(sample, visited_clusters,
visited_hashes):
sample_hash_val = sample[HashKeys.simhash]
if sample_hash_val not in hash2cluster:
                # single-sample cluster, but we still need to check the
                # hash value for exact duplicates.
if sample_hash_val in visited_hashes:
return False
else:
visited_hashes.add(sample_hash_val)
return True
else:
cluster_num = hash2cluster[sample_hash_val]
if show_num > 0 and cluster_num in dup_pairs \
and len(dup_pairs[cluster_num]) < 2:
dup_pairs[cluster_num].append(sample)
# regular cluster, check cluster number.
if cluster_num in visited_clusters:
return False
else:
visited_clusters.add(cluster_num)
return True
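
        # Illustrative sketch of the keep-first strategy above (hypothetical
        # samples, not from the original source): if three samples hash to
        # h1, h2 and h3 and all three hashes belong to cluster 0, only the
        # first sample that dataset.filter visits passes; the later two are
        # dropped because cluster 0 is already in `visited_clusters`.
        # Samples outside any cluster are deduplicated by exact hash value
        # through `visited_hashes` instead.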
cluster_record = set()
hash_record = set()
dataset = dataset.filter(
_filter_simhash_dup_helper,
fn_kwargs=dict(visited_clusters=cluster_record,
visited_hashes=hash_record),
load_from_cache_file=False if show_num > 0 else True)
logger.info(f'Keep {len(dataset)} samples after SimHash dedup.')
return dataset, dup_pairs