Source code for data_juicer.ops.deduplicator.document_simhash_deduplicator

# Some code here has been modified from:
# https://github.com/bigscience-workshop/data-preparation
# --------------------------------------------------------

from collections import defaultdict, deque
from typing import Dict, Optional, Set

import numpy as np
import regex
from loguru import logger
from pydantic import PositiveInt

from data_juicer.utils.constant import HashKeys
from data_juicer.utils.lazy_loader import LazyLoader

from ..base_op import OPERATORS, Deduplicator
from ..common.helper_func import split_on_whitespace

simhash = LazyLoader("simhash", "simhash-pybind")

OP_NAME = "document_simhash_deduplicator"


@OPERATORS.register_module(OP_NAME)
class DocumentSimhashDeduplicator(Deduplicator):
    """Deduplicates samples at the document level using SimHash.

    This operator computes SimHash values for each sample and removes
    duplicates based on a specified Hamming distance threshold. It supports
    different tokenization methods: 'space', 'punctuation', and 'character'.
    The SimHash is computed over shingles of a given window size, and the
    deduplication process clusters similar documents and retains only one
    from each cluster. The default mode converts text to lowercase and can
    ignore specific patterns. The key metric, Hamming distance, is used to
    determine similarity between SimHash values.

    Important notes:

    - The `ignore_pattern` parameter can be used to exclude certain
      substrings during SimHash computation.
    - For punctuation-based tokenization, the `ignore_pattern` should not
      include punctuations to avoid conflicts.
    - The `hamming_distance` must be less than the number of blocks
      (`num_blocks`).
    - Only the first sample in each cluster is retained by default.
    """
    def __init__(
        self,
        tokenization: str = "space",
        window_size: PositiveInt = 6,
        lowercase: bool = True,
        ignore_pattern: Optional[str] = None,
        num_blocks: PositiveInt = 6,
        hamming_distance: PositiveInt = 4,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param tokenization: tokenization method for sample texts. It should
            be one of [space, punctuation, character]. For English-like
            languages, we recommend 'space'; for Chinese-like languages, we
            recommend 'character'.
        :param window_size: window size of shingling
        :param lowercase: whether to convert text to lower case first
        :param ignore_pattern: a regular expression pattern; sub-strings
            matching it are removed before computing simhash
        :param num_blocks: number of blocks in simhash computing
        :param hamming_distance: the max hamming distance threshold in
            near-duplicate detection. When the hamming distance of two sample
            texts is <= this threshold, they are regarded as similar samples
            and this op will only keep one of them after deduplication. This
            threshold should always be less than num_blocks.
        """
        # about simhash computation
        super().__init__(*args, **kwargs)
        self.tokenization = tokenization
        self.window_size = window_size
        self.lowercase = lowercase
        self.ignore_pattern = ignore_pattern
        if self.ignore_pattern:
            self.ignore_pattern = regex.compile(self.ignore_pattern)

        # check parameters
        if self.ignore_pattern and self.tokenization == "punctuation":
            logger.warning(
                "Be careful that tokenization with punctuations "
                "won't work if the ignore pattern includes "
                "punctuations."
            )
        self.punctuation_pattern = regex.compile(r"\p{P}")

        # about deduplication
        self.num_blocks = num_blocks
        self.hamming_distance = hamming_distance
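
    # Editorial note: `hamming_distance` must be smaller than `num_blocks`
    # because block-based simhash indexing relies on the pigeonhole principle:
    # if two 64-bit fingerprints differ in at most `hamming_distance` bits,
    # the differing bits can fall into at most `hamming_distance` of the
    # `num_blocks` blocks, so the fingerprints agree exactly on at least
    # `num_blocks - hamming_distance` blocks (at least 2 of 6 with the
    # defaults). Candidate pairs can then be found through exact lookups on
    # block combinations instead of comparing every pair of hashes.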

    def compute_hash(self, sample):
        """
        Compute simhash values for the sample.

        :param sample: input sample
        :return: sample with simhash value.
        """
        # check if it's computed already
        if HashKeys.simhash in sample:
            return sample

        text = sample[self.text_key]

        if self.lowercase:
            text = text.lower()
        if self.ignore_pattern:
            text = self.ignore_pattern.sub("", text)

        # get tokens for different tokenization method
        tokens = []
        if self.tokenization == "character":
            tokens = [
                str.encode(text[i : i + self.window_size])
                for i in range(len(text) - self.window_size + 1)
            ]
        elif self.tokenization == "punctuation":
            tokens = self.punctuation_pattern.split(text)
            tokens = [
                str.encode(" ".join(tokens[i : i + self.window_size]))
                for i in range(len(tokens) - self.window_size + 1)
            ]
        elif self.tokenization == "space":
            tokens = split_on_whitespace(text)
            tokens = [
                str.encode(" ".join(tokens[i : i + self.window_size]))
                for i in range(len(tokens) - self.window_size + 1)
            ]
        else:
            raise NotImplementedError(
                f"Unimplemented tokenization method [{self.tokenization}]"
            )

        # compute simhash
        sample[HashKeys.simhash] = str(
            np.uint64(simhash.compute(map(simhash.unsigned_hash, tokens)))
        )
        return sample
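
    # Editorial illustration: with `tokenization="space"` and `window_size=6`,
    # a text of seven words "w1 w2 w3 w4 w5 w6 w7" yields two shingles,
    # "w1 w2 w3 w4 w5 w6" and "w2 w3 w4 w5 w6 w7". Each shingle is hashed with
    # `simhash.unsigned_hash`, and `simhash.compute` folds the per-shingle
    # hashes into a single 64-bit fingerprint. Texts with fewer tokens (or
    # characters, for 'character' tokenization) than `window_size` produce no
    # shingles.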

    def process(self, dataset, show_num=0):
        """
        For doc-level, dataset --> dataset.

        :param dataset: input dataset
        :param show_num: number of traced samples used when tracer is open.
        :return: deduplicated dataset and the sampled duplicate pairs.
        """
        # no need to deduplicate because too few samples
        if len(dataset) <= 1:
            return dataset, {}

        # find matches
        logger.info(f"Start querying {len(dataset)} samples.")
        matches = simhash.find_all(
            np.uint64(dataset[HashKeys.simhash]),
            self.num_blocks,
            self.hamming_distance,
        )
        logger.info(f"Querying done, found {len(matches)} matches.")

        # build an undirected graph of near-duplicate hash pairs
        graph = defaultdict(dict)
        for x, y in matches:
            x = str(x)
            y = str(y)
            graph[x][y] = graph[y][x] = True

        hash2ids: Dict[str, Set[str]] = defaultdict(set)
        hashes: Set[str] = set(dataset[HashKeys.simhash])
        hash2cluster: Dict[str, int] = {}
        visited: Set[str] = set()
        cluster_id: int = 0

        for sid, hash_val in enumerate(dataset[HashKeys.simhash]):
            hash2ids[hash_val].add(str(sid))

        # clustering
        dup_pairs = {}  # store duplicate pairs when show_num > 0
        while hashes:
            hash_val = hashes.pop()
            if hash_val in visited:
                continue

            # if this hash value is not in the matches list, it's regarded as
            # a single cluster
            if hash_val not in graph:
                continue

            # otherwise, BFS to find the cluster
            q = deque([hash_val])
            visited.add(hash_val)
            hash2cluster[hash_val] = cluster_id
            if show_num > 0 and len(dup_pairs) < show_num:
                dup_pairs[cluster_id] = []

            while q:
                curr = q.popleft()
                for neighbor in graph[curr]:
                    if neighbor in visited:
                        continue
                    visited.add(neighbor)
                    q.append(neighbor)
                    hash2cluster[neighbor] = cluster_id

            cluster_id += 1
        logger.info(f"Found {cluster_id} clusters and {len(graph)} hashes.")

        # filter duplicated samples
        # NOTICE: For now, we only keep the first sample in a cluster. Maybe
        # there are some better strategies later.
        def _filter_simhash_dup_helper(sample, visited_clusters, visited_hashes):
            sample_hash_val = sample[HashKeys.simhash]
            if sample_hash_val not in hash2cluster:
                # single-sample cluster, but we still need to check the hash
                # value to drop exact-hash duplicates
                if sample_hash_val in visited_hashes:
                    return False
                else:
                    visited_hashes.add(sample_hash_val)
                    return True
            else:
                cluster_num = hash2cluster[sample_hash_val]
                if (
                    show_num > 0
                    and cluster_num in dup_pairs
                    and len(dup_pairs[cluster_num]) < 2
                ):
                    dup_pairs[cluster_num].append(sample)
                # regular cluster, check cluster number
                if cluster_num in visited_clusters:
                    return False
                else:
                    visited_clusters.add(cluster_num)
                    return True

        cluster_record = set()
        hash_record = set()
        dataset = dataset.filter(
            _filter_simhash_dup_helper,
            fn_kwargs=dict(
                visited_clusters=cluster_record, visited_hashes=hash_record
            ),
            load_from_cache_file=False if show_num > 0 else True,
        )
        logger.info(f"Keep {len(dataset)} samples after SimHash dedup.")

        return dataset, dup_pairs
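

# ---------------------------------------------------------------------------
# Usage sketch (editorial illustration; not part of the original module).
# Assumptions: data_juicer and its dependencies (including simhash-pybind)
# are installed, the operator's default `text_key` is "text", and the input
# is a HuggingFace `datasets.Dataset`. The sample texts and variable names
# below are made up; in practice this op is normally driven through a
# data_juicer config and executor rather than called directly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from datasets import Dataset

    demo_ds = Dataset.from_list(
        [
            {"text": "Data processing pipelines benefit from removing near-duplicate documents."},
            {"text": "Data processing pipelines benefit from removing near-duplicate documents."},
            {"text": "A completely different document that talks about something else entirely."},
        ]
    )

    op = DocumentSimhashDeduplicator(
        tokenization="space",
        window_size=6,
        num_blocks=6,
        hamming_distance=4,
    )

    # Stage 1: attach a simhash fingerprint to every sample.
    demo_ds = demo_ds.map(op.compute_hash)

    # Stage 2: cluster matching fingerprints and keep one sample per cluster.
    # The first two samples are exact duplicates, so they share a fingerprint
    # and only one of them is kept; near-duplicates within the Hamming
    # threshold would be clustered and reduced the same way.
    deduped_ds, dup_pairs = op.process(demo_ds)
    print(deduped_ds["text"])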