from collections import defaultdict
from typing import Dict, Set, Tuple
import numpy as np
from data_juicer.utils.constant import HashKeys
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.mm_utils import load_data_with_context, load_image
from ..base_op import OPERATORS, Deduplicator
from ..op_fusion import LOADED_IMAGES
from .document_deduplicator import DocumentDeduplicator
imgdedup_methods = LazyLoader('imgdedup_methods', 'imagededup.methods')
OP_NAME = 'image_deduplicator'
HASH_METHOD = {'phash', 'dhash', 'whash', 'ahash'}
[docs]
def get_hash_method(method_name):
    """
    Look up the image hasher registered under the given method name.

    :param method_name: one of 'phash', 'dhash', 'whash' or 'ahash'
    :return: the matching attribute of the lazily loaded
        ``imagededup.methods`` module (PHash, DHash, WHash or AHash)
    :raises KeyError: if ``method_name`` is not one of the names above
    """
    hasher_by_name = dict(
        phash=imgdedup_methods.PHash,
        dhash=imgdedup_methods.DHash,
        whash=imgdedup_methods.WHash,
        ahash=imgdedup_methods.AHash,
    )
    return hasher_by_name[method_name]
[docs]
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class ImageDeduplicator(Deduplicator):
    """
    Deduplicator to deduplicate samples at document-level using exact
    matching of image hashes between documents.

    Each sample gets one string made of the hashes of all its images; two
    samples are duplicates when these strings (and, optionally, their text
    hashes) are equal.
    """

    def __init__(self,
                 method: str = 'phash',
                 consider_text: bool = False,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param method: hash method for image, one of ``HASH_METHOD``
        :param consider_text: whether to consider text hash together with
            image hash when applying deduplication.
        :param args: extra args
        :param kwargs: extra args
        :raises ValueError: if ``method`` is not a supported hash method
        """
        super().__init__(*args, **kwargs)
        if method not in HASH_METHOD:
            raise ValueError(f'Hash method [{method}] is not supported. '
                             f'Can only be one of {HASH_METHOD}.')
        self.hasher = get_hash_method(method)()
        self.consider_text = consider_text
        # only build the text deduplicator when it is actually needed
        self.text_dedup_op = None
        if self.consider_text:
            self.text_dedup_op = DocumentDeduplicator(**kwargs)

    def compute_hash(self, sample, context=False):
        """
        Compute the image hash of a sample and store it in the sample.

        :param sample: the sample to hash
        :param context: forwarded to ``load_data_with_context`` when
            loading the images of the sample
        :return: the sample with ``HashKeys.imagehash`` filled in; the value
            is an empty string when the sample has no images
        """
        # get hash of text first
        if self.consider_text:
            sample = self.text_dedup_op.compute_hash(sample)

        # check if it's computed already
        if HashKeys.imagehash in sample:
            return sample

        # default for samples without any image
        sample[HashKeys.imagehash] = ''
        if self.image_key not in sample or not sample[self.image_key]:
            return sample

        # load images
        loaded_image_keys = sample[self.image_key]
        sample, images = load_data_with_context(sample, context,
                                                loaded_image_keys, load_image)

        # the sample hash is the concatenation of the hashes of its images,
        # in the iteration order of the loaded images
        sample[HashKeys.imagehash] = ''.join(
            self.hasher.encode_image(image_array=np.array(images[key]))
            for key in images)
        return sample

    def process(self, dataset, show_num=0):
        """
        For doc-level, dataset --> dataset.

        :param dataset: input dataset
        :param show_num: number of traced samples used when tracer is
            open.
        :return: deduplicated dataset and the sampled duplicate pairs
            (a dict mapping a duplicated hash to at most two samples
            carrying it; empty when ``show_num`` is not positive).
        """
        # no need to deduplicate because too few samples
        if len(dataset) <= 1:
            return dataset, {}

        dup_hashes = set()
        if show_num > 0:
            # sample duplicate pairs
            if self.consider_text:
                hash_keys = zip(dataset[HashKeys.imagehash],
                                dataset[HashKeys.hash])
            else:
                hash_keys = dataset[HashKeys.imagehash]

            # key: image hash string, or (image hash, text hash) tuple
            # value: indices of the samples sharing that key
            hash2ids: Dict[object, Set[int]] = defaultdict(set)
            for sid, hash_key in enumerate(hash_keys):
                # NOTE: a tuple key is always truthy, so this only skips
                # image-less samples when text is not considered
                if hash_key:
                    hash2ids[hash_key].add(sid)

            # keep the show_num hashes shared by the most samples
            ranked = sorted(hash2ids.items(),
                            key=lambda item: len(item[1]),
                            reverse=True)
            duplicated = [key for key, ids in ranked if len(ids) > 1]
            dup_hashes = set(duplicated[:show_num])

        dup_pairs = {hash_key: [] for hash_key in dup_hashes}
        seen_hashes = set()

        def _filter_dup_helper(sample, hashes):
            if self.consider_text:
                sample_hash = (sample[HashKeys.imagehash],
                               sample[HashKeys.hash])
            else:
                sample_hash = sample[HashKeys.imagehash]
            # samples with an empty hash are always kept
            if not sample_hash:
                return True
            if show_num > 0 and sample_hash in dup_hashes \
                    and len(dup_pairs[sample_hash]) < 2:
                # tracer is open and not enough duplicate sample pairs
                dup_pairs[sample_hash].append(sample)
            # keep only the first sample seen for each hash
            if sample_hash in hashes:
                return False
            hashes.add(sample_hash)
            return True

        # the cache is bypassed while tracing so that the helper really
        # runs and dup_pairs gets populated
        dataset = dataset.filter(
            _filter_dup_helper,
            fn_kwargs=dict(hashes=seen_hashes),
            load_from_cache_file=show_num <= 0)  # num_proc=1
        return dataset, dup_pairs