# Source code for data_juicer.ops.deduplicator.document_deduplicator

# Some code here has been modified from:
# https://github.com/bigscience-workshop/data-preparation/blob/main/preprocessing/training/01a_catalogue_cleaning_and_filtering/clean_helpers/deduplication.py
# --------------------------------------------------------

import hashlib
import string
from collections import defaultdict
from typing import Dict, Set

import regex as re

from data_juicer.utils.constant import HashKeys

from ..base_op import OPERATORS, Deduplicator


@OPERATORS.register_module('document_deduplicator')
class DocumentDeduplicator(Deduplicator):
    """
    Deduplicator to deduplicate samples at document-level using exact
    matching. Using md5 hash to deduplicate samples.
    """

    def __init__(self,
                 lowercase: bool = False,
                 ignore_non_character: bool = False,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param lowercase: Whether to convert sample text to lower case.
        :param ignore_non_character: Whether to ignore non-alphabet
            characters, including whitespace, digits, and punctuation.
        :param args: extra args.
        :param kwargs: extra kwargs.
        """
        super().__init__(*args, **kwargs)
        self.lowercase = lowercase
        self.remove_non_character_regex = re.compile(
            rf'\s+|\d+|[{re.escape(string.punctuation)}]'
        ) if ignore_non_character else None
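
    # Illustrative note (added here, not part of the upstream source): with
    # ignore_non_character=True, the compiled regex strips whitespace,
    # digits, and punctuation before hashing, e.g.
    #     'Hello, World 42!'  ->  'HelloWorld'
    # so the exact-match comparison becomes insensitive to those characters.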

    def compute_hash(self, sample):
        """
        Compute md5 hash values for the sample.

        :param sample: input sample
        :return: sample with md5 hash value.
        """
        # check if it's computed already
        if HashKeys.hash in sample:
            return sample

        text = sample[self.text_key]
        if self.lowercase:
            text = text.lower()
        if self.remove_non_character_regex:
            text = self.remove_non_character_regex.sub('', text)

        def _get_hash(txt):
            return hashlib.md5(txt.strip().encode('utf-8')).hexdigest()

        sample[HashKeys.hash] = _get_hash(text)
        return sample
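
    # Illustrative note (added here, not part of the upstream source):
    # assuming the operator's default text_key of 'text', compute_hash maps
    #     {'text': 'Hello'}
    # to
    #     {'text': 'Hello', HashKeys.hash: hashlib.md5(b'Hello').hexdigest()}
    # and returns samples that already carry a hash unchanged.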

    def process(self, dataset, show_num=0):
        """
        For doc-level, dataset --> dataset.

        :param dataset: input dataset
        :param show_num: number of traced samples used when tracer is open.
        :return: deduplicated dataset and the sampled duplicate pairs.
        """
        # no need to deduplicate because there are too few samples
        if len(dataset) <= 1:
            return dataset, {}

        dup_hashes = None
        if show_num > 0:
            # sample duplicate pairs
            hash2ids: Dict[str, Set[int]] = defaultdict(set)
            for sid, hash_val in enumerate(dataset[HashKeys.hash]):
                hash2ids[hash_val].add(sid)
            dup_samples = sorted(hash2ids.items(),
                                 key=lambda x: len(x[1]),
                                 reverse=True)
            dup_hashes = set([
                item[0] for item in dup_samples if len(item[1]) > 1
            ][:show_num])

        def _filter_dup_helper(sample, hashes):
            hash_val = sample[HashKeys.hash]
            if show_num > 0 and hash_val in dup_hashes \
                    and len(dup_pairs[hash_val]) < 2:
                # tracer is open and not enough duplicate sample pairs
                dup_pairs[hash_val].append(sample)
            if hash_val in hashes:
                return False
            else:
                hashes.add(hash_val)
                return True

        hashes = set()
        dup_pairs = {hash_v: [] for hash_v in dup_hashes} \
            if dup_hashes else {}
        dataset = dataset.filter(
            _filter_dup_helper,
            fn_kwargs=dict(hashes=hashes),
            load_from_cache_file=False if show_num > 0 else True)  # num_proc=1
        return dataset, dup_pairs
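

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the upstream module). It
# assumes a HuggingFace `datasets.Dataset` with a 'text' column, matching the
# operator's default text_key, and that `compute_hash` is applied via `map`
# before `process` filters out exact duplicates.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from datasets import Dataset

    ds = Dataset.from_dict(
        {'text': ['Hello, World!', 'hello world', 'another doc']})

    op = DocumentDeduplicator(lowercase=True, ignore_non_character=True)
    ds = ds.map(op.compute_hash)    # attach an md5 hash to every sample
    ds, dup_pairs = op.process(ds)  # keep the first sample per hash value

    # With lowercase + ignore_non_character, the first two texts both
    # normalize to 'helloworld' and collapse into a single sample.
    print(ds['text'])               # ['Hello, World!', 'another doc']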