Source code for data_juicer.ops.base_op

import copy
import traceback
from functools import wraps

import numpy as np
import pyarrow as pa
from loguru import logger

from data_juicer import is_cuda_available
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import size_to_bytes
from data_juicer.utils.process_utils import calculate_np
from data_juicer.utils.registry import Registry

OPERATORS = Registry('Operators')
UNFORKABLE = Registry('Unforkable')
NON_STATS_FILTERS = Registry('Non-stats Filters')
TAGGING_OPS = Registry('Tagging Operators')


def convert_list_dict_to_dict_list(samples):
    # reconstruct samples from "list of dicts" to "dict of lists"
    keys = samples[0].keys()
    res_samples = {}
    for key in keys:
        res_samples[key] = [s[key] for s in samples]
    return res_samples


def convert_dict_list_to_list_dict(samples):
    # reconstruct samples from "dict of lists" to "list of dicts"
    reconstructed_samples = []
    # take any key, since all columns should be of the same length
    keys = list(samples.keys())
    for i in range(len(samples[keys[0]])):
        reconstructed_samples.append(
            {key: samples[key][i] for key in samples})
    return reconstructed_samples

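# A minimal round-trip sketch of the two converters above (illustrative
# values, not part of the original module):
#
#   >>> batch = {'text': ['a', 'b'], 'idx': [0, 1]}
#   >>> convert_dict_list_to_list_dict(batch)
#   [{'text': 'a', 'idx': 0}, {'text': 'b', 'idx': 1}]
#   >>> convert_list_dict_to_dict_list(
#   ...     convert_dict_list_to_list_dict(batch)) == batch
#   True
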
def convert_arrow_to_python(method):

    @wraps(method)
    def wrapper(sample, *args, **kwargs):
        if isinstance(sample, pa.Table):
            sample = sample.to_pydict()
        return method(sample, *args, **kwargs)

    return wrapper

def catch_map_batches_exception(method):
    """
    For batched-map sample-level fault tolerance.
    """

    @wraps(method)
    @convert_arrow_to_python
    def wrapper(samples, *args, **kwargs):
        try:
            return method(samples, *args, **kwargs)
        except Exception as e:
            logger.error(
                f'An error occurred in mapper operation when processing '
                f'samples {samples}, {type(e)}: {e}')
            traceback.print_exc()
            ret = {key: [] for key in samples.keys()}
            ret[Fields.stats] = []
            ret[Fields.source_file] = []
            return ret

    return wrapper


def catch_map_single_exception(method, return_sample=True):
    """
    For single-map sample-level fault tolerance.
    The input sample is expected to be a batch of size 1.
    """

    def is_batched(sample):
        val_iter = iter(sample.values())
        first_val = next(val_iter)
        if not isinstance(first_val, list):
            return False
        first_len = len(first_val)
        return all(
            isinstance(val, list) and len(val) == first_len
            for val in val_iter)

    @wraps(method)
    @convert_arrow_to_python
    def wrapper(sample, *args, **kwargs):
        if is_batched(sample):
            try:
                sample = convert_dict_list_to_list_dict(sample)[0]
                res = method(sample, *args, **kwargs)
                if return_sample:
                    return convert_list_dict_to_dict_list([res])
                else:
                    return [res]
            except Exception as e:
                logger.error(
                    f'An error occurred in mapper operation when processing '
                    f'sample {sample}, {type(e)}: {e}')
                traceback.print_exc()
                ret = {key: [] for key in sample.keys()}
                ret[Fields.stats] = []
                ret[Fields.source_file] = []
                return ret
        else:
            # without fault tolerance
            return method(sample, *args, **kwargs)

    return wrapper

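# Behavior sketch for the fault-tolerance wrappers above (the failing op and
# the field values are hypothetical, not part of this module): a sample that
# raises is logged and replaced by empty columns instead of aborting the map.
#
#   >>> def broken_op(sample):
#   ...     raise ValueError('boom')
#   >>> safe_op = catch_map_single_exception(broken_op)
#   >>> out = safe_op({'text': ['hello']})  # batched form, batch size 1
#   >>> out['text']
#   []
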
class OP:

    _accelerator = 'cpu'
    _batched_op = False

    def __init__(self, *args, **kwargs):
        """
        Base class of operators.

        :param text_key: the key name of field that stores sample texts
            to be processed.
        :param image_key: the key name of field that stores sample image list
            to be processed.
        :param audio_key: the key name of field that stores sample audio list
            to be processed.
        :param video_key: the key name of field that stores sample video list
            to be processed.
        :param query_key: the key name of field that stores sample queries.
        :param response_key: the key name of field that stores responses.
        :param history_key: the key name of field that stores history of
            queries and responses.
        :param index_key: index the samples before process if not None.
        """
        # init data keys
        self.text_key = kwargs.get('text_key', 'text')
        self.image_key = kwargs.get('image_key', 'images')
        self.audio_key = kwargs.get('audio_key', 'audios')
        self.video_key = kwargs.get('video_key', 'videos')

        self.query_key = kwargs.get('query_key', 'query')
        self.response_key = kwargs.get('response_key', 'response')
        self.history_key = kwargs.get('history_key', 'history')

        self.index_key = kwargs.get('index_key', None)

        self.batch_size = kwargs.get('batch_size', 1000)

        # whether the model can be accelerated using cuda
        _accelerator = kwargs.get('accelerator', None)
        if _accelerator is not None:
            self.accelerator = _accelerator
        else:
            self.accelerator = self._accelerator

        # parameters to determine the number of procs for this op
        self.num_proc = kwargs.get('num_proc', None)
        self.cpu_required = kwargs.get('cpu_required', 1)
        self.mem_required = kwargs.get('mem_required', 0)
        if isinstance(self.mem_required, str):
            self.mem_required = size_to_bytes(self.mem_required) / 1024**3

        self.turbo = kwargs.get('turbo', False)

        # nested wrappers
        from data_juicer.core.data import wrap_func_with_nested_access
        for name in ['process', 'compute_stats', 'compute_hash']:
            method = getattr(self, name, None)
            if method and callable(method):
                setattr(self, f'_{name}', method)
                method = wrap_func_with_nested_access(method)
                setattr(self, name, method)
    def is_batched_op(self):
        return self._batched_op

    def process(self, *args, **kwargs):
        raise NotImplementedError

    def use_cuda(self):
        return self.accelerator == 'cuda' and is_cuda_available()

    def runtime_np(self):
        op_proc = calculate_np(self._name, self.mem_required,
                               self.cpu_required, self.num_proc,
                               self.use_cuda())
        logger.debug(
            f'Op [{self._name}] running with number of procs: {op_proc}')
        return op_proc
    def remove_extra_parameters(self, param_dict, keys=None):
        """
        At the beginning of the init of the mapper op, call
        self.remove_extra_parameters(locals()) to get the init parameter
        dict of the op for convenience.
        """
        if keys is None:
            param_dict = {
                k: v
                for k, v in param_dict.items() if not k.startswith('_')
            }
            param_dict.pop('self', None)
        else:
            param_dict = {
                k: v
                for k, v in param_dict.items() if k not in keys
            }
        return param_dict

    def add_parameters(self, init_parameter_dict, **extra_param_dict):
        """
        Add parameters for each sample; need to keep extra_param_dict and
        init_parameter_dict unchanged.
        """
        related_parameters = copy.deepcopy(init_parameter_dict)
        related_parameters.update(extra_param_dict)
        return related_parameters
    def run(self, dataset):
        from data_juicer.core.data import NestedDataset
        if not isinstance(dataset, NestedDataset):
            dataset = NestedDataset(dataset)
        # add meta field for OPs that produce tags
        if self._name in TAGGING_OPS.modules \
                and Fields.meta not in dataset.features:
            from data_juicer.core.data import add_same_content_to_new_column
            dataset = dataset.map(add_same_content_to_new_column,
                                  fn_kwargs={
                                      'new_column_name': Fields.meta,
                                      'initial_value': {}
                                  },
                                  num_proc=self.runtime_np(),
                                  batch_size=self.batch_size,
                                  desc='Adding new column for meta')
        if self.index_key is not None:

            def add_index(sample, idx):
                sample[self.index_key] = idx
                return sample

            dataset = dataset.map(add_index, with_indices=True)

        return dataset

    def empty_history(self):
        return np.empty((0, 0), dtype=str)

class Mapper(OP):

    def __init__(self, *args, **kwargs):
        """
        Base class that conducts data editing.

        :param text_key: the key name of field that stores sample texts
            to be processed.
        :param image_key: the key name of field that stores sample image list
            to be processed.
        :param audio_key: the key name of field that stores sample audio list
            to be processed.
        :param video_key: the key name of field that stores sample video list
            to be processed.
        :param query_key: the key name of field that stores sample queries.
        :param response_key: the key name of field that stores responses.
        :param history_key: the key name of field that stores history of
            queries and responses.
        """
        super(Mapper, self).__init__(*args, **kwargs)

        # runtime wrappers
        if self.is_batched_op():
            self.process = catch_map_batches_exception(self.process_batched)
        else:
            self.process = catch_map_single_exception(self.process_single)

    # the process method is not allowed to be overridden by subclasses
    def __init_subclass__(cls, **kwargs):
        not_allowed_list = ['process']
        for method_name in not_allowed_list:
            if method_name in cls.__dict__:
                raise TypeError(
                    f'Method {method_name} cannot be overridden by subclass '
                    f'{cls.__name__}. Please implement {method_name}_single '
                    f'or {method_name}_batched.')
    def process_batched(self, samples, *args, **kwargs):
        keys = samples.keys()
        first_key = next(iter(keys))
        num_samples = len(samples[first_key])

        new_keys = {}
        for i in range(num_samples):
            this_sample = {key: samples[key][i] for key in keys}
            res_sample = self.process_single(this_sample, *args, **kwargs)
            res_keys = res_sample.keys()
            for key in res_keys:
                if key not in keys:
                    if key not in new_keys:
                        new_keys.update({key: []})
                    new_keys[key].append(res_sample[key])
                else:
                    samples[key][i] = res_sample[key]

        for k, v in new_keys.items():
            samples[k] = v

        return samples

    def process_single(self, sample):
        """
        For sample level, sample --> sample

        :param sample: sample to process
        :return: processed sample
        """
        raise NotImplementedError
    def run(self, dataset, *, exporter=None, tracer=None):
        dataset = super(Mapper, self).run(dataset)
        new_dataset = dataset.map(
            self.process,
            num_proc=self.runtime_np(),
            with_rank=self.use_cuda(),
            batch_size=self.batch_size,
            desc=self._name + '_process',
        )
        if tracer:
            tracer.trace_mapper(self._name, dataset, new_dataset,
                                self.text_key)
        return new_dataset

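# A minimal Mapper subclass sketch (the op name and the lower-casing logic are
# illustrative assumptions, not ops shipped with this module; it assumes the
# register_module decorator exposed by the Registry instances defined above):
#
#   @OPERATORS.register_module('text_lowercase_mapper')
#   class TextLowercaseMapper(Mapper):
#
#       def process_single(self, sample):
#           sample[self.text_key] = sample[self.text_key].lower()
#           return sample
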
class Filter(OP):

    def __init__(self, *args, **kwargs):
        """
        Base class that filters out samples based on computed stats.

        :param text_key: the key name of field that stores sample texts
            to be processed.
        :param image_key: the key name of field that stores sample image list
            to be processed.
        :param audio_key: the key name of field that stores sample audio list
            to be processed.
        :param video_key: the key name of field that stores sample video list
            to be processed.
        :param query_key: the key name of field that stores sample queries.
        :param response_key: the key name of field that stores responses.
        :param history_key: the key name of field that stores history of
            queries and responses.
        """
        super(Filter, self).__init__(*args, **kwargs)
        self.stats_export_path = kwargs.get('stats_export_path', None)

        # runtime wrappers
        if self.is_batched_op():
            self.compute_stats = catch_map_batches_exception(
                self.compute_stats_batched)
            self.process = catch_map_batches_exception(self.process_batched)
        else:
            self.compute_stats = catch_map_single_exception(
                self.compute_stats_single)
            self.process = catch_map_single_exception(self.process_single,
                                                      return_sample=False)

    # the compute_stats and process methods are not allowed to be overridden
    # by subclasses
    def __init_subclass__(cls, **kwargs):
        not_allowed_list = ['compute_stats', 'process']
        for method_name in not_allowed_list:
            if method_name in cls.__dict__:
                raise TypeError(
                    f'Method {method_name} cannot be overridden by subclass '
                    f'{cls.__name__}. Please implement {method_name}_single '
                    f'or {method_name}_batched.')
    def compute_stats_batched(self, samples, *args, **kwargs):
        keys = samples.keys()
        num_samples = len(samples[Fields.stats])
        for i in range(num_samples):
            this_sample = {key: samples[key][i] for key in keys}
            res_sample = self.compute_stats_single(this_sample, *args,
                                                   **kwargs)
            samples[Fields.stats][i] = res_sample[Fields.stats]
            if 'context' in kwargs and kwargs['context']:
                samples[Fields.context][i] = res_sample[Fields.context]

        return samples

    def process_batched(self, samples):
        return map(lambda stat: self.process_single({Fields.stats: stat}),
                   samples[Fields.stats])

    def compute_stats_single(self, sample, context=False):
        """
        Compute stats for the sample which is used as a metric to decide
        whether to filter this sample.

        :param sample: input sample.
        :param context: whether to store context information of intermediate
            vars in the sample temporarily.
        :return: sample with computed stats
        """
        raise NotImplementedError

    def process_single(self, sample):
        """
        For sample level, sample --> Boolean.

        :param sample: sample to decide whether to filter
        :return: true for keeping and false for filtering
        """
        raise NotImplementedError
    def run(self, dataset, *, exporter=None, tracer=None, reduce=True):
        dataset = super(Filter, self).run(dataset)
        # add stats field for Filters that produce stats
        if self._name not in NON_STATS_FILTERS.modules \
                and Fields.stats not in dataset.features:
            from data_juicer.core.data import add_same_content_to_new_column
            dataset = dataset.map(add_same_content_to_new_column,
                                  fn_kwargs={
                                      'new_column_name': Fields.stats,
                                      'initial_value': {}
                                  },
                                  num_proc=self.runtime_np(),
                                  batch_size=self.batch_size,
                                  desc='Adding new column for stats')
        dataset = dataset.map(self.compute_stats,
                              num_proc=self.runtime_np(),
                              with_rank=self.use_cuda(),
                              batch_size=self.batch_size,
                              desc=self._name + '_compute_stats')
        if exporter and self.stats_export_path is not None:
            exporter.export_compute_stats(dataset, self.stats_export_path)
        if reduce:
            new_dataset = dataset.filter(self.process,
                                         num_proc=self.runtime_np(),
                                         batch_size=self.batch_size,
                                         desc=self._name + '_process')
            if tracer:
                tracer.trace_filter(self._name, dataset, new_dataset)
            return new_dataset
        else:
            return dataset

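# A minimal Filter subclass sketch (the op name, stat key, and threshold are
# illustrative assumptions; the built-in filters use dedicated stats-key
# constants instead of raw strings): compute_stats_single records a per-sample
# stat and process_single turns it into a keep/drop decision.
#
#   @OPERATORS.register_module('text_length_filter')
#   class TextLengthFilter(Filter):
#
#       def __init__(self, min_len=10, *args, **kwargs):
#           super().__init__(*args, **kwargs)
#           self.min_len = min_len
#
#       def compute_stats_single(self, sample, context=False):
#           sample[Fields.stats]['text_len'] = len(sample[self.text_key])
#           return sample
#
#       def process_single(self, sample):
#           return sample[Fields.stats]['text_len'] >= self.min_len
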
class Deduplicator(OP):

    def __init__(self, *args, **kwargs):
        """
        Base class that conducts deduplication.

        :param text_key: the key name of field that stores sample texts
            to be processed.
        :param image_key: the key name of field that stores sample image list
            to be processed.
        :param audio_key: the key name of field that stores sample audio list
            to be processed.
        :param video_key: the key name of field that stores sample video list
            to be processed.
        :param query_key: the key name of field that stores sample queries.
        :param response_key: the key name of field that stores responses.
        :param history_key: the key name of field that stores history of
            queries and responses.
        """
        super(Deduplicator, self).__init__(*args, **kwargs)

        # runtime wrappers
        if self.is_batched_op():
            self.compute_hash = catch_map_batches_exception(self.compute_hash)
        else:
            self.compute_hash = catch_map_single_exception(self.compute_hash)
    def compute_hash(self, sample):
        """
        Compute hash values for the sample.

        :param sample: input sample
        :return: sample with computed hash value.
        """
        raise NotImplementedError

    def process(self, dataset, show_num=0):
        """
        For doc-level, dataset --> dataset.

        :param dataset: input dataset
        :param show_num: number of traced samples used when tracer is open.
        :return: deduplicated dataset and the sampled duplicate pairs.
        """
        raise NotImplementedError
    def run(self, dataset, *, exporter=None, tracer=None, reduce=True):
        dataset = super(Deduplicator, self).run(dataset)
        dataset = dataset.map(self.compute_hash,
                              num_proc=self.runtime_np(),
                              with_rank=self.use_cuda(),
                              desc=self._name + '_compute_hash')
        if reduce:
            show_num = tracer.show_num if tracer else 0
            new_dataset, dup_pairs = self.process(dataset, show_num)
            if tracer:
                tracer.trace_deduplicator(self._name, dup_pairs)
            return new_dataset
        else:
            return dataset

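# A minimal Deduplicator subclass sketch (the hash field name and the
# exact-match logic are illustrative assumptions; the built-in deduplicators
# use dedicated hash-key constants and more robust hashing):
#
#   import hashlib
#
#   class ExactTextDeduplicator(Deduplicator):
#
#       def compute_hash(self, sample):
#           sample['text_hash'] = hashlib.md5(
#               sample[self.text_key].encode('utf-8')).hexdigest()
#           return sample
#
#       def process(self, dataset, show_num=0):
#           seen, keep = set(), []
#           for i, h in enumerate(dataset['text_hash']):
#               if h not in seen:
#                   seen.add(h)
#                   keep.append(i)
#           # return the deduplicated dataset and (empty) traced duplicate pairs
#           return dataset.select(keep), {}
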
class Selector(OP):

    def __init__(self, *args, **kwargs):
        """
        Base class that conducts selection at dataset level.

        :param text_key: the key name of field that stores sample texts
            to be processed.
        :param image_key: the key name of field that stores sample image list
            to be processed.
        :param audio_key: the key name of field that stores sample audio list
            to be processed.
        :param video_key: the key name of field that stores sample video list
            to be processed.
        :param query_key: the key name of field that stores sample queries.
        :param response_key: the key name of field that stores responses.
        :param history_key: the key name of field that stores history of
            queries and responses.
        """
        super(Selector, self).__init__(*args, **kwargs)
    def process(self, dataset):
        """
        Dataset --> dataset.

        :param dataset: input dataset
        :return: selected dataset.
        """
        raise NotImplementedError

    def run(self, dataset, *, exporter=None, tracer=None):
        dataset = super(Selector, self).run(dataset)
        new_dataset = self.process(dataset)
        if tracer:
            tracer.trace_filter(self._name, dataset, new_dataset)
        return new_dataset

class Grouper(OP):

    def __init__(self, *args, **kwargs):
        """
        Base class that groups samples.

        :param text_key: the key name of field that stores sample texts
            to be processed.
        :param image_key: the key name of field that stores sample image list
            to be processed.
        :param audio_key: the key name of field that stores sample audio list
            to be processed.
        :param video_key: the key name of field that stores sample video list
            to be processed.
        :param query_key: the key name of field that stores sample queries.
        :param response_key: the key name of field that stores responses.
        :param history_key: the key name of field that stores history of
            queries and responses.
        """
        super(Grouper, self).__init__(*args, **kwargs)
    def process(self, dataset):
        """
        Dataset --> dataset.

        :param dataset: input dataset
        :return: dataset of batched samples.
        """
        raise NotImplementedError

    def run(self, dataset, *, exporter=None, tracer=None):
        dataset = super(Grouper, self).run(dataset)
        batched_samples = self.process(dataset)
        from data_juicer.core.data import NestedDataset
        new_dataset = NestedDataset.from_list(batched_samples)
        if tracer:
            tracer.trace_filter(self._name, dataset, new_dataset)
        return new_dataset

class Aggregator(OP):

    def __init__(self, *args, **kwargs):
        """
        Base class that aggregates batched samples (e.g. the output of some
        Grouper OP).

        :param text_key: the key name of field that stores sample texts
            to be processed.
        :param image_key: the key name of field that stores sample image list
            to be processed.
        :param audio_key: the key name of field that stores sample audio list
            to be processed.
        :param video_key: the key name of field that stores sample video list
            to be processed.
        :param query_key: the key name of field that stores sample queries.
        :param response_key: the key name of field that stores responses.
        :param history_key: the key name of field that stores history of
            queries and responses.
        """
        super(Aggregator, self).__init__(*args, **kwargs)

        self.process = catch_map_single_exception(self.process_single)
    def process_single(self, sample):
        """
        For sample level, batched sample --> sample,
        the input must be the output of some Grouper OP.

        :param sample: batched sample to aggregate
        :return: aggregated sample
        """
        raise NotImplementedError

    def run(self, dataset, *, exporter=None, tracer=None):
        dataset = super(Aggregator, self).run(dataset)
        new_dataset = dataset.map(
            self.process,
            num_proc=self.runtime_np(),
            with_rank=self.use_cuda(),
            batch_size=self.batch_size,
            desc=self._name + '_process',
        )
        if tracer:
            tracer.trace_mapper(self._name, dataset, new_dataset,
                                self.text_key)
        return new_dataset