Source code for data_juicer.ops.deduplicator.document_simhash_deduplicator
# Some code here has been modified from:
# https://github.com/bigscience-workshop/data-preparation
# --------------------------------------------------------

from collections import defaultdict, deque
from typing import Dict, Optional, Set

import numpy as np
import regex
from loguru import logger
from pydantic import PositiveInt

from data_juicer.utils.constant import HashKeys
from data_juicer.utils.lazy_loader import LazyLoader

from ..base_op import OPERATORS, Deduplicator
from ..common.helper_func import split_on_whitespace

simhash = LazyLoader('simhash', 'simhash-pybind')

OP_NAME = 'document_simhash_deduplicator'
@OPERATORS.register_module(OP_NAME)
class DocumentSimhashDeduplicator(Deduplicator):
    """Deduplicator to deduplicate samples at document-level using SimHash."""
    def __init__(self,
                 tokenization: str = 'space',
                 window_size: PositiveInt = 6,
                 lowercase: bool = True,
                 ignore_pattern: Optional[str] = None,
                 num_blocks: PositiveInt = 6,
                 hamming_distance: PositiveInt = 4,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param tokenization: tokenization method for sample texts. It
            should be one of [space, punctuation, character]. For
            English-like languages, we recommend using 'space', and for
            Chinese-like languages, we recommend using 'character'.
        :param window_size: window size of shingling.
        :param lowercase: whether to convert text to lower case first.
        :param ignore_pattern: a regular expression pattern; sub-strings
            matching this pattern are removed from the text before
            computing simhash.
        :param num_blocks: number of blocks in simhash computing.
        :param hamming_distance: the max hamming distance threshold in
            near-duplicate detection. When the hamming distance of two
            sample texts is <= this threshold, they are regarded as
            similar samples and this op will only keep one of them after
            deduplication. This threshold should always be less than
            num_blocks.
        """
        # about simhash computation
        super().__init__(*args, **kwargs)
        self.tokenization = tokenization
        self.window_size = window_size
        self.lowercase = lowercase
        self.ignore_pattern = ignore_pattern
        if self.ignore_pattern:
            self.ignore_pattern = regex.compile(self.ignore_pattern)

        # check parameters
        if self.ignore_pattern and self.tokenization == 'punctuation':
            logger.warning('Be careful that tokenization with punctuations '
                           'won\'t work if the ignore pattern includes '
                           'punctuations.')
        self.punctuation_pattern = regex.compile(r'\p{P}')

        # about deduplication
        self.num_blocks = num_blocks
        self.hamming_distance = hamming_distance
    def compute_hash(self, sample):
        """
        Compute simhash values for the sample.

        :param sample: input sample
        :return: sample with simhash value.
        """
        # check if it's computed already
        if HashKeys.simhash in sample:
            return sample

        text = sample[self.text_key]

        if self.lowercase:
            text = text.lower()
        if self.ignore_pattern:
            text = self.ignore_pattern.sub('', text)

        # get tokens for different tokenization method
        tokens = []
        if self.tokenization == 'character':
            tokens = [
                str.encode(text[i:i + self.window_size])
                for i in range(len(text) - self.window_size + 1)
            ]
        elif self.tokenization == 'punctuation':
            tokens = self.punctuation_pattern.split(text)
            tokens = [
                str.encode(' '.join(tokens[i:i + self.window_size]))
                for i in range(len(tokens) - self.window_size + 1)
            ]
        elif self.tokenization == 'space':
            tokens = split_on_whitespace(text)
            tokens = [
                str.encode(' '.join(tokens[i:i + self.window_size]))
                for i in range(len(tokens) - self.window_size + 1)
            ]
        else:
            raise NotImplementedError(
                f'Unimplemented tokenization method [{self.tokenization}]')

        # compute simhash
        sample[HashKeys.simhash] = str(
            np.uint64(simhash.compute(map(simhash.unsigned_hash, tokens))))
        return sample
    def process(self, dataset, show_num=0):
        """
        For doc-level, dataset --> dataset.

        :param dataset: input dataset
        :param show_num: number of traced samples used when tracer is open.
        :return: deduplicated dataset and the sampled duplicate pairs.
        """
        # no need to deduplicate because too few samples
        if len(dataset) <= 1:
            return dataset, {}

        # find matches
        logger.info(f'Start querying {len(dataset)} samples.')
        matches = simhash.find_all(
            np.uint64(dataset[HashKeys.simhash]),
            self.num_blocks,
            self.hamming_distance,
        )
        logger.info(f'Querying done, found {len(matches)} matches.')

        # build a graph connecting hashes that matched each other
        graph = defaultdict(dict)
        for x, y in matches:
            x = str(x)
            y = str(y)
            graph[x][y] = graph[y][x] = True

        hash2ids: Dict[str, Set[str]] = defaultdict(set)
        hashes: Set[str] = set(dataset[HashKeys.simhash])
        hash2cluster: Dict[str, int] = {}
        visited: Set[str] = set()
        cluster_id: int = 0

        for sid, hash_val in enumerate(dataset[HashKeys.simhash]):
            hash2ids[hash_val].add(str(sid))

        # clustering
        dup_pairs = {}  # store duplicate pairs when show_num > 0
        while hashes:
            hash_val = hashes.pop()
            if hash_val in visited:
                continue

            # if this hash value is not in the matches list, it's regarded as
            # a single cluster
            if hash_val not in graph:
                continue

            # Otherwise, BFS to find the cluster
            q = deque([hash_val])
            visited.add(hash_val)
            hash2cluster[hash_val] = cluster_id
            if show_num > 0 and len(dup_pairs) < show_num:
                dup_pairs[cluster_id] = []

            while q:
                curr = q.popleft()
                for neighbor in graph[curr]:
                    if neighbor in visited:
                        continue
                    visited.add(neighbor)
                    q.append(neighbor)
                    hash2cluster[neighbor] = cluster_id

            cluster_id += 1
        logger.info(f'Found {cluster_id} clusters and {len(graph)} hashes.')

        # filter duplicated samples
        # NOTICE: For now, we only keep the first sample in a cluster. Maybe
        # there are some better strategies later.
        def _filter_simhash_dup_helper(sample, visited_clusters,
                                       visited_hashes):
            sample_hash_val = sample[HashKeys.simhash]
            if sample_hash_val not in hash2cluster:
                # single-sample cluster, we need to check hash value still.
                if sample_hash_val in visited_hashes:
                    return False
                else:
                    visited_hashes.add(sample_hash_val)
                    return True
            else:
                cluster_num = hash2cluster[sample_hash_val]
                if show_num > 0 and cluster_num in dup_pairs \
                        and len(dup_pairs[cluster_num]) < 2:
                    dup_pairs[cluster_num].append(sample)
                # regular cluster, check cluster number.
                if cluster_num in visited_clusters:
                    return False
                else:
                    visited_clusters.add(cluster_num)
                    return True

        cluster_record = set()
        hash_record = set()
        dataset = dataset.filter(
            _filter_simhash_dup_helper,
            fn_kwargs=dict(visited_clusters=cluster_record,
                           visited_hashes=hash_record),
            load_from_cache_file=False if show_num > 0 else True)
        logger.info(f'Keep {len(dataset)} samples after SimHash dedup.')

        return dataset, dup_pairs
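In a Data-Juicer pipeline this op is normally scheduled by the executor from a config file. Purely as an illustration of the two stages above (per-sample hashing, then cluster-and-filter), below is a minimal standalone sketch; the sample texts, the use of a plain HuggingFace `datasets.Dataset`, and the default `text_key='text'` are assumptions for the example, not part of this module.

# Minimal standalone usage sketch (assumptions: a plain HuggingFace Dataset
# and the default text_key='text'; in practice the op runs inside the
# Data-Juicer executor).
from datasets import Dataset

from data_juicer.ops.deduplicator.document_simhash_deduplicator import \
    DocumentSimhashDeduplicator

ds = Dataset.from_dict({
    'text': [
        'Today is Sunday and it is a happy day!',
        'Today is Sunday and it is a happy day !',
        'A completely different document about something else entirely.',
    ]
})

op = DocumentSimhashDeduplicator(tokenization='space',
                                 window_size=6,
                                 num_blocks=6,
                                 hamming_distance=4)

# stage 1: compute one 64-bit simhash per sample, stored under HashKeys.simhash
ds = ds.map(op.compute_hash)
# stage 2: cluster hashes within the hamming-distance threshold and keep one
# sample per cluster; dup_pairs holds traced duplicates when show_num > 0
ds, dup_pairs = op.process(ds)
print(len(ds))  # number of samples kept after SimHash dedup

Note the constraint from the constructor docstring: hamming_distance must stay below num_blocks, since the block-based candidate search can only guarantee recall for up to num_blocks - 1 differing bit-blocks.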