# Source code for data_juicer.ops.base_op

import copy
import traceback
from functools import wraps

import numpy as np
import pyarrow as pa
from loguru import logger

from data_juicer import is_cuda_available
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import size_to_bytes
from data_juicer.utils.process_utils import calculate_np
from data_juicer.utils.registry import Registry

OPERATORS = Registry('Operators')
UNFORKABLE = Registry('Unforkable')


def convert_list_dict_to_dict_list(samples):
    """
    Reconstruct samples from "list of dicts" to "dict of lists".

    The keys of the first sample define the output columns, so every other
    sample must contain at least those keys.

    :param samples: list of sample dicts.
    :return: dict mapping each key to the list of that key's values, in
        sample order. An empty input yields an empty dict.
    """
    if not samples:
        # no rows -> no columns; avoids IndexError on samples[0]
        return {}
    keys = samples[0].keys()
    return {key: [s[key] for s in samples] for key in keys}


def convert_dict_list_to_list_dict(samples):
    """
    Reconstruct samples from "dict of lists" to "list of dicts".

    The number of rows is taken from the first column; all columns are
    expected to have the same length.

    :param samples: dict mapping each key to a list of per-row values.
    :return: list of per-row sample dicts. An empty dict yields an empty
        list.
    """
    if not samples:
        # no columns -> no rows; avoids IndexError on keys[0]
        return []
    keys = list(samples.keys())
    # take any key, since they should be of same length
    num_rows = len(samples[keys[0]])
    return [{key: samples[key][i] for key in keys} for i in range(num_rows)]


def convert_arrow_to_python(method):
    """
    Decorator that turns a ``pyarrow.Table`` first argument into a Python
    "dict of lists" (via ``Table.to_pydict``) before calling ``method``.

    Any other first argument is forwarded unchanged.
    """

    @wraps(method)
    def wrapper(sample, *args, **kwargs):
        if not isinstance(sample, pa.Table):
            return method(sample, *args, **kwargs)
        return method(sample.to_pydict(), *args, **kwargs)

    return wrapper


def catch_map_batches_exception(method):
    """
    For batched-map sample-level fault tolerance.

    Wraps a method that takes a batch of samples ("dict of lists"); an Arrow
    table argument is converted to a Python dict first. If ``method`` raises,
    the error and traceback are logged and an empty batch is returned
    instead. That batch has the input's columns plus ``Fields.stats`` and
    ``Fields.source_file``, all as empty lists, so the failing batch is
    dropped instead of the exception propagating.

    :param method: batched callable to protect.
    :return: the wrapped callable.
    """

    @wraps(method)
    @convert_arrow_to_python
    def wrapper(samples, *args, **kwargs):
        try:
            return method(samples, *args, **kwargs)
        except Exception as e:
            # uses the module-level loguru ``logger``; no local re-import
            logger.error(
                f'An error occurred in mapper operation when processing '
                f'samples {samples}, {type(e)}: {e}')
            traceback.print_exc()
            ret = {key: [] for key in samples.keys()}
            ret[Fields.stats] = []
            ret[Fields.source_file] = []
            return ret

    return wrapper


def catch_map_single_exception(method):
    """
    For single-map sample-level fault tolerance.
    The input sample is expected batch_size = 1.

    When the input arrives as a batch ("dict of lists"), its first row is
    passed to ``method``. If that raises, the error is logged and an empty
    batch is returned (the sample's keys plus ``Fields.stats`` and
    ``Fields.source_file``, all empty lists).

    On success, a dict result (a processed sample) is converted back into a
    one-row batch. Any other result, such as the boolean returned by a
    filter's ``process_single``, is wrapped in a one-element list.

    Non-batched input is forwarded to ``method`` without fault tolerance.

    :param method: single-sample callable to protect.
    :return: the wrapped callable.
    """

    def is_batched(sample):
        # Batched means every value is a list and all lists share the same
        # length. An empty sample is treated as non-batched instead of
        # raising StopIteration.
        val_iter = iter(sample.values())
        first_val = next(val_iter, None)
        if not isinstance(first_val, list):
            return False
        first_len = len(first_val)
        return all(
            isinstance(val, list) and len(val) == first_len
            for val in val_iter)

    @wraps(method)
    @convert_arrow_to_python
    def wrapper(sample, *args, **kwargs):
        if not is_batched(sample):
            # without fault tolerance
            return method(sample, *args, **kwargs)
        try:
            sample = convert_dict_list_to_list_dict(sample)[0]
            res = method(sample, *args, **kwargs)
        except Exception as e:
            logger.error(
                f'An error occurred in mapper operation when processing '
                f'sample {sample}, {type(e)}: {e}')
            traceback.print_exc()
            ret = {key: [] for key in sample.keys()}
            ret[Fields.stats] = []
            ret[Fields.source_file] = []
            return ret
        if isinstance(res, dict):
            # a processed sample -> one-row "dict of lists"
            return convert_list_dict_to_dict_list([res])
        # a non-sample result (e.g. a keep/drop flag) -> one-element list
        return [res]

    return wrapper


class OP:

    # device used when no ``accelerator`` kwarg is given
    _accelerator = 'cpu'
    # subclasses set this to True to select their batched code path
    _batched_op = False

    def __init__(self, *args, **kwargs):
        """
        Base class of operators.

        :param text_key: the key name of field that stores sample texts
            to be processed.
        :param image_key: the key name of field that stores sample image list
            to be processed
        :param audio_key: the key name of field that stores sample audio list
            to be processed
        :param video_key: the key name of field that stores sample video list
            to be processed
        :param query_key: key name of the query field (default 'query').
        :param response_key: key name of the response field
            (default 'response').
        :param history_key: key name of the history field
            (default 'history').
        :param batch_size: batch size forwarded to dataset mapping
            (default 1000).
        :param accelerator: device for this op; falls back to the
            class-level ``_accelerator`` when omitted or None.
        :param num_proc: requested number of processes (default None).
        :param cpu_required: CPUs required per process (default 1).
        :param mem_required: memory required per process (default 0). A
            string is converted with ``size_to_bytes`` and divided by
            1024**3 (presumably yielding GiB).
        :param turbo: turbo flag (default False).
        """
        # init data keys
        self.text_key = kwargs.get('text_key', 'text')
        self.image_key = kwargs.get('image_key', 'images')
        self.audio_key = kwargs.get('audio_key', 'audios')
        self.video_key = kwargs.get('video_key', 'videos')

        self.query_key = kwargs.get('query_key', 'query')
        self.response_key = kwargs.get('response_key', 'response')
        self.history_key = kwargs.get('history_key', 'history')

        self.batch_size = kwargs.get('batch_size', 1000)

        # an explicit accelerator kwarg wins over the class-level default
        requested_accelerator = kwargs.get('accelerator', None)
        self.accelerator = (self._accelerator if requested_accelerator is None
                            else requested_accelerator)

        # parameters to determine the number of procs for this op
        self.num_proc = kwargs.get('num_proc', None)
        self.cpu_required = kwargs.get('cpu_required', 1)
        mem_required = kwargs.get('mem_required', 0)
        if isinstance(mem_required, str):
            mem_required = size_to_bytes(mem_required) / 1024**3
        self.mem_required = mem_required

        self.turbo = kwargs.get('turbo', False)

        # nested wrappers: keep each raw method under ``_<name>`` and expose
        # the wrap_func_with_nested_access version under ``<name>``
        from data_juicer.core.data import wrap_func_with_nested_access
        for attr in ('process', 'compute_stats', 'compute_hash'):
            raw = getattr(self, attr, None)
            if not (raw and callable(raw)):
                continue
            setattr(self, '_' + attr, raw)
            setattr(self, attr, wrap_func_with_nested_access(raw))

    @classmethod
    def is_batched_op(cls):
        """Return the class-level ``_batched_op`` flag."""
        return cls._batched_op

    def process(self, *args, **kwargs):
        """Operator-specific processing; must be provided by subclasses."""
        raise NotImplementedError

    def use_cuda(self):
        """True when this op targets 'cuda' and CUDA is available."""
        return self.accelerator == 'cuda' and is_cuda_available()

    def runtime_np(self):
        """
        Number of processes to run this op with, as computed by
        ``calculate_np``.

        NOTE(review): ``self._name`` is not assigned in this class;
        presumably it is set externally (e.g. at registration) — confirm.
        """
        num_procs = calculate_np(self._name, self.mem_required,
                                 self.cpu_required, self.num_proc,
                                 self.use_cuda())
        logger.debug(
            f'Op [{self._name}] running with number of procs:{num_procs}')
        return num_procs

    def remove_extra_parameters(self, param_dict, keys=None):
        """
            At the beginning of the init of the mapper op, call
            self.remove_extra_parameters(locals())
            to get the init parameter dict of the op for convenience.

            Without ``keys``, entries whose names start with '_' and the
            'self' entry are dropped; with ``keys``, exactly those names
            are dropped. A new dict is returned.
        """
        if keys is not None:
            return {k: v for k, v in param_dict.items() if k not in keys}
        public_params = {
            k: v
            for k, v in param_dict.items() if not k.startswith('_')
        }
        public_params.pop('self', None)
        return public_params

    def add_parameters(self, init_parameter_dict, **extra_param_dict):
        """
            Add parameters for each sample. A deep copy of
            init_parameter_dict is updated with extra_param_dict, so
            neither input is modified.
        """
        merged = copy.deepcopy(init_parameter_dict)
        merged.update(extra_param_dict)
        return merged

    def run(self, dataset):
        """Return ``dataset`` as a NestedDataset, wrapping it if needed."""
        from data_juicer.core.data import NestedDataset
        if isinstance(dataset, NestedDataset):
            return dataset
        return NestedDataset(dataset)

    def empty_history(self):
        """Return an empty (0, 0) numpy array of dtype str."""
        return np.empty((0, 0), dtype=str)


class Mapper(OP):

    def __init__(self, *args, **kwargs):
        """
        Base class that conducts data editing.

        :param text_key: the key name of field that stores sample texts
            to be processed.
        :param image_key: the key name of field that stores sample image list
            to be processed
        :param audio_key: the key name of field that stores sample audio list
            to be processed
        :param video_key: the key name of field that stores sample video list
            to be processed
        """
        super(Mapper, self).__init__(*args, **kwargs)

        # runtime wrappers: route ``process`` to the batched or single-sample
        # implementation, each wrapped with sample-level fault tolerance
        if self.is_batched_op():
            self.process = catch_map_batches_exception(self.process_batched)
        else:
            self.process = catch_map_single_exception(self.process_single)

    # the process method is not allowed to be overridden by subclasses
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        not_allowed_list = ['process']
        for method_name in not_allowed_list:
            if method_name in cls.__dict__:
                raise TypeError(
                    f'Method {method_name} cannot be overridden by subclass '
                    f'{cls.__name__}. Please implement {method_name}_single '
                    f'or {method_name}_batched.')

    def process_batched(self, samples, *args, **kwargs):
        """
        Default batched implementation: apply ``process_single`` to every
        row of a "dict of lists" batch.

        Existing columns are updated in place. Keys that ``process_single``
        adds become new columns, matching the single-sample path where every
        returned key is kept. Rows that lack such a key get None in that
        column.

        :param samples: batch of samples as a dict of equal-length lists.
        :return: the updated batch (the same dict object).
        """
        keys = list(samples.keys())
        if not keys:
            # empty batch: nothing to process (avoids StopIteration)
            return samples
        existing_keys = set(keys)
        # take any key, since all columns should be of the same length
        num_samples = len(samples[keys[0]])
        new_columns = {}
        for i in range(num_samples):
            this_sample = {key: samples[key][i] for key in keys}
            res_sample = self.process_single(this_sample, *args, **kwargs)
            for key in keys:
                samples[key][i] = res_sample[key]
            for key, value in res_sample.items():
                if key in existing_keys:
                    continue
                if key not in new_columns:
                    new_columns[key] = [None] * num_samples
                new_columns[key][i] = value
        # attach new columns only after iterating, so ``keys`` stays stable
        samples.update(new_columns)
        return samples

    def process_single(self, sample):
        """
        For sample level, sample --> sample

        :param sample: sample to process
        :return: processed sample
        """
        raise NotImplementedError

    def run(self, dataset, *, exporter=None, tracer=None):
        """
        Map ``self.process`` over ``dataset`` and optionally trace the change.

        :param dataset: input dataset (wrapped as NestedDataset if needed).
        :param exporter: unused by mappers; accepted for a uniform interface.
        :param tracer: optional tracer notified via ``trace_mapper``.
        :return: the mapped dataset.
        """
        dataset = super(Mapper, self).run(dataset)
        new_dataset = dataset.map(
            self.process,
            num_proc=self.runtime_np(),
            with_rank=self.use_cuda(),
            batch_size=self.batch_size,
            desc=self._name + '_process',
        )
        if tracer:
            tracer.trace_mapper(self._name, dataset, new_dataset,
                                self.text_key)
        return new_dataset


class Filter(OP):

    def __init__(self, *args, **kwargs):
        """
        Base class that removes specific info.

        :param text_key: the key name of field that stores sample texts
            to be processed
        :param image_key: the key name of field that stores sample image list
            to be processed
        :param audio_key: the key name of field that stores sample audio list
            to be processed
        :param video_key: the key name of field that stores sample video list
            to be processed
        :param stats_export_path: optional path passed to the exporter when
            exporting computed stats (default None: no export).
        """
        super(Filter, self).__init__(*args, **kwargs)
        self.stats_export_path = kwargs.get('stats_export_path', None)

        # runtime wrappers: route ``compute_stats``/``process`` to the batched
        # or single-sample implementations, with sample-level fault tolerance
        if self.is_batched_op():
            self.compute_stats = catch_map_batches_exception(
                self.compute_stats_batched)
            self.process = catch_map_batches_exception(self.process_batched)
        else:
            self.compute_stats = catch_map_single_exception(
                self.compute_stats_single)
            self.process = catch_map_single_exception(self.process_single)

    # compute_stats and process are not allowed to be overridden
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        not_allowed_list = ['compute_stats', 'process']
        for method_name in not_allowed_list:
            if method_name in cls.__dict__:
                raise TypeError(
                    f'Method {method_name} cannot be overridden by subclass '
                    f'{cls.__name__}. Please implement {method_name}_single '
                    f'or {method_name}_batched.')

    def compute_stats_batched(self, samples, *args, **kwargs):
        """
        Default batched implementation: run ``compute_stats_single`` on every
        row and write the resulting stats back into ``samples`` in place.

        When a truthy ``context`` keyword is given, the per-row
        ``Fields.context`` value is written back as well.

        :param samples: batch of samples as a dict of equal-length lists.
        :return: the updated batch (the same dict object).
        """
        keys = samples.keys()
        num_samples = len(samples[Fields.stats])
        store_context = kwargs.get('context')
        for i in range(num_samples):
            this_sample = {key: samples[key][i] for key in keys}
            res_sample = self.compute_stats_single(this_sample, *args,
                                                   **kwargs)
            samples[Fields.stats][i] = res_sample[Fields.stats]
            if store_context:
                samples[Fields.context][i] = res_sample[Fields.context]
        return samples

    def process_batched(self, samples):
        """
        Default batched implementation: decide keep/drop for every row from
        its stats via ``process_single``.

        The result is materialized as a list, not a lazy ``map``. Errors from
        ``process_single`` are therefore raised while the fault-tolerance
        wrapper is still active, rather than later when the result is
        consumed.

        :param samples: batch containing a ``Fields.stats`` column.
        :return: list of per-row booleans (True keeps the row).
        """
        return [
            self.process_single({Fields.stats: stat})
            for stat in samples[Fields.stats]
        ]

    def compute_stats_single(self, sample, context=False):
        """
        Compute stats for the sample which is used as a metric to decide
        whether to filter this sample.

        :param sample: input sample.
        :param context: whether to store context information of intermediate
            vars in the sample temporarily.
        :return: sample with computed stats
        """
        raise NotImplementedError

    def process_single(self, sample):
        """
        For sample level, sample --> Boolean.

        :param sample: sample to decide whether to filter
        :return: true for keeping and false for filtering
        """
        raise NotImplementedError

    def run(self, dataset, *, exporter=None, tracer=None, reduce=True):
        """
        Compute stats for ``dataset`` and, if ``reduce``, filter it.

        A ``Fields.stats`` column (initialized to ``{}``) is added first when
        missing. Stats are exported when an exporter is given and
        ``stats_export_path`` is set.

        :param dataset: input dataset (wrapped as NestedDataset if needed).
        :param exporter: optional exporter for computed stats.
        :param tracer: optional tracer notified via ``trace_filter``.
        :param reduce: when False, return the dataset with stats but without
            filtering.
        :return: the filtered dataset, or the dataset with stats.
        """
        dataset = super(Filter, self).run(dataset)
        if Fields.stats not in dataset.features:
            from data_juicer.core.data import add_same_content_to_new_column
            dataset = dataset.map(add_same_content_to_new_column,
                                  fn_kwargs={
                                      'new_column_name': Fields.stats,
                                      'initial_value': {}
                                  },
                                  num_proc=self.runtime_np(),
                                  batch_size=self.batch_size,
                                  desc='Adding new column for stats')
        dataset = dataset.map(self.compute_stats,
                              num_proc=self.runtime_np(),
                              with_rank=self.use_cuda(),
                              batch_size=self.batch_size,
                              desc=self._name + '_compute_stats')
        if exporter and self.stats_export_path is not None:
            exporter.export_compute_stats(dataset, self.stats_export_path)
        if reduce:
            new_dataset = dataset.filter(self.process,
                                         num_proc=self.runtime_np(),
                                         batch_size=self.batch_size,
                                         desc=self._name + '_process')
            if tracer:
                tracer.trace_filter(self._name, dataset, new_dataset)
            return new_dataset
        else:
            return dataset


class Deduplicator(OP):

    def __init__(self, *args, **kwargs):
        """
        Base class that conducts deduplication.

        :param text_key: the key name of field that stores sample texts
            to be processed
        :param image_key: the key name of field that stores sample image list
            to be processed
        :param audio_key: the key name of field that stores sample audio list
            to be processed
        :param video_key: the key name of field that stores sample video list
            to be processed
        """
        super(Deduplicator, self).__init__(*args, **kwargs)

        # runtime wrappers: add sample-level fault tolerance to compute_hash
        if self.is_batched_op():
            self.compute_hash = catch_map_batches_exception(self.compute_hash)
        else:
            self.compute_hash = catch_map_single_exception(self.compute_hash)

    def compute_hash(self, sample):
        """
        Compute hash values for the sample.

        :param sample: input sample
        :return: sample with computed hash value.
        """
        raise NotImplementedError

    def process(self, dataset, show_num=0):
        """
        For doc-level, dataset --> dataset.

        :param dataset: input dataset
        :param show_num: number of traced samples used when tracer is open.
        :return: deduplicated dataset and the sampled duplicate pairs.
        """
        raise NotImplementedError

    def run(self, dataset, *, exporter=None, tracer=None, reduce=True):
        """
        Compute hashes for ``dataset`` and, if ``reduce``, deduplicate it.

        The configured ``batch_size`` is forwarded to the hashing ``map``, as
        ``Mapper.run`` and ``Filter.run`` do, so batched deduplicators use the
        configured batch size.

        :param dataset: input dataset (wrapped as NestedDataset if needed).
        :param exporter: unused here; accepted for a uniform interface.
        :param tracer: optional tracer; its ``show_num`` is passed to
            ``process`` and duplicate pairs go to ``trace_deduplicator``.
        :param reduce: when False, return the dataset with hashes only.
        :return: the deduplicated dataset, or the dataset with hashes.
        """
        dataset = super(Deduplicator, self).run(dataset)
        dataset = dataset.map(self.compute_hash,
                              num_proc=self.runtime_np(),
                              with_rank=self.use_cuda(),
                              batch_size=self.batch_size,
                              desc=self._name + '_compute_hash')
        if reduce:
            show_num = tracer.show_num if tracer else 0
            new_dataset, dup_pairs = self.process(dataset, show_num)
            if tracer:
                tracer.trace_deduplicator(self._name, dup_pairs)
            return new_dataset
        else:
            return dataset


class Selector(OP):

    def __init__(self, *args, **kwargs):
        """
        Base class that conducts selection in dataset-level.

        :param text_key: the key name of field that stores sample texts
            to be processed
        :param image_key: the key name of field that stores sample image list
            to be processed
        :param audio_key: the key name of field that stores sample audio list
            to be processed
        :param video_key: the key name of field that stores sample video list
            to be processed
        """
        super().__init__(*args, **kwargs)

    def process(self, dataset):
        """
        Dataset --> dataset.

        :param dataset: input dataset
        :return: selected dataset.
        """
        raise NotImplementedError

    def run(self, dataset, *, exporter=None, tracer=None):
        """
        Select from ``dataset`` via ``process``; report to the tracer (as a
        filter step) when one is given. ``exporter`` is accepted but unused.
        """
        source = super().run(dataset)
        selected = self.process(source)
        if tracer:
            tracer.trace_filter(self._name, source, selected)
        return selected