from abc import ABC, abstractmethod
import ray
from data_juicer.utils.constant import HashKeys
from data_juicer.utils.lazy_loader import LazyLoader
from ..base_op import Filter
redis = LazyLoader('redis', 'redis')
MERSENNE_PRIME = (1 << 61) - 1
@ray.remote(scheduling_strategy='SPREAD')
class DedupSet:
    """
    Ray remote class holding the set of keys it has already been asked about.

    It is declared with the 'SPREAD' scheduling strategy.
    """

    def __init__(self):
        # Every key ever passed to ``is_unique`` on this instance.
        self.hash_record = set()

    def is_unique(self, key):
        """
        Report whether ``key`` is new to this set, recording it if so.

        :param key: hashable value to look up
        :return: True the first time ``key`` is seen, False afterwards
        """
        if key in self.hash_record:
            return False
        self.hash_record.add(key)
        return True
class Backend(ABC):
    """
    Backend for deduplicator.

    Concrete backends must implement both ``__init__`` and ``is_unique``;
    a subclass that leaves either one abstract cannot be instantiated.
    """

    @abstractmethod
    def __init__(self, *args, **kwargs):
        """Set up whatever storage the backend uses to remember hashes."""
        pass

    @abstractmethod
    def is_unique(self, md5_value: str):
        """
        Tell whether ``md5_value`` is seen for the first time.

        :param md5_value: hash string identifying a sample
        :return: truthy if the value had not been seen before
        """
        pass
class ActorBackend(Backend):
    """
    Ray actor backend for deduplicator.

    Hash values are sharded over ``dedup_set_num`` ``DedupSet`` actors; a
    given hash string always maps to the same actor.
    """

    def __init__(self, dedup_set_num: int):
        """
        Create the pool of ``DedupSet`` actors.

        :param dedup_set_num: number of actors to create, at least 1
        :raises ValueError: if ``dedup_set_num`` is smaller than 1
        """
        # With zero actors the modulo in ``is_unique`` divides by zero and
        # with a negative count the actor list is empty, so fail early with
        # a clear message instead of on the first lookup.
        if dedup_set_num < 1:
            raise ValueError(
                f'dedup_set_num must be at least 1, got {dedup_set_num}')
        self.dedup_set_num = dedup_set_num
        self.dedup_sets = [
            DedupSet.remote() for _ in range(self.dedup_set_num)
        ]

    def is_unique(self, md5_value: str):
        """
        Ask the actor responsible for ``md5_value`` whether it is new.

        :param md5_value: hash string identifying a sample
        :return: result of the remote ``DedupSet.is_unique`` call (blocks
            until the actor has answered)
        """
        # Interpret the encoded string as one little-endian integer, then
        # reduce it to an index into ``self.dedup_sets``.
        hash_as_int = int.from_bytes(md5_value.encode(), byteorder='little')
        dedup_set_id = hash_as_int % MERSENNE_PRIME % self.dedup_set_num
        pending = self.dedup_sets[dedup_set_id].is_unique.remote(md5_value)
        return ray.get(pending)
class RedisBackend(Backend):
    """
    Redis backend for deduplicator.

    Each hash value is stored as a Redis key; a value is unique when the
    key could be created.
    """

    def __init__(self, redis_address: str):
        """
        Connect to Redis and clear the database the client is using.

        :param redis_address: URL of the redis server, passed to
            ``redis.from_url``
        """
        self.redis_address = redis_address
        self.redis_client = redis.from_url(url=self.redis_address)
        # Wipe existing keys so that leftovers are not reported as
        # duplicates. The positional argument of redis-py's ``flushdb`` is
        # the ``asynchronous`` flag, not a database index, so the default
        # (synchronous) call is spelled without it.
        self.redis_client.flushdb()

    def is_unique(self, md5_value: str):
        """
        Try to create the key ``md5_value``; succeed only if it was absent.

        :param md5_value: hash string identifying a sample
        :return: True if the key was newly set, False if it already existed
        """
        # Normalised to bool so this backend returns the same type as the
        # actor-based one.
        return bool(self.redis_client.setnx(md5_value, 1))
class RayBasicDeduplicator(Filter):
    """
    A basic exact matching deduplicator for RAY.
    Although its functionality is deduplication,
    it is implemented as Filter sub-class.
    """

    # TODO: Set a more reasonable value
    EMPTY_HASH_VALUE = 'EMPTY'

    def __init__(self,
                 backend: str = 'ray_actor',
                 redis_address: str = 'redis://localhost:6379',
                 *args,
                 **kwargs):
        """
        Initialization.

        :param backend: the backend for dedup, either 'ray_actor' or 'redis'
        :param redis_address: the address of redis server
        :param args: extra args
        :param kwargs: extra args
        :raises ValueError: if ``backend`` is neither 'ray_actor' nor 'redis'
        """
        super().__init__(*args, **kwargs)
        self.redis_address = redis_address
        if backend == 'ray_actor':
            # One DedupSet actor per two cluster CPUs, but never fewer than
            # one: a cluster reporting less than 2 CPUs (or no 'CPU' entry)
            # would otherwise yield zero actors and break every lookup.
            cpu_total = ray.cluster_resources().get('CPU', 0)
            dedup_set_num = max(1, int(cpu_total / 2))
            self.backend = ActorBackend(dedup_set_num)
        elif backend == 'redis':
            # TODO: add a barrier to ensure that flushdb is performed before
            # the operator is called
            self.backend = RedisBackend(redis_address)
        else:
            raise ValueError(f'Unknown backend: {backend}')

    def calculate_hash(self, sample, context=False):
        """
        Calculate hash value for the sample.

        Not implemented here; subclasses are expected to override it.
        """
        raise NotImplementedError

    def compute_stats_single(self, sample, context=False):
        """
        Hash the sample and record whether that hash is new to the backend.

        :param sample: sample to inspect; it is modified in place
        :param context: forwarded unchanged to ``calculate_hash``
        :return: the same sample with ``HashKeys.is_unique`` set
        """
        # compute hash
        md5_value = self.calculate_hash(sample, context)
        # check existing
        sample[HashKeys.is_unique] = self.backend.is_unique(md5_value)
        return sample

    def process_single(self, sample):
        """
        Return the uniqueness flag stored by ``compute_stats_single``.

        NOTE(review): presumably the Filter framework keeps samples for
        which this is truthy - confirm against ``base_op.Filter``.
        """
        return sample[HashKeys.is_unique]