import os
from time import time
from typing import Optional
from jsonargparse import Namespace
from loguru import logger
from pydantic import PositiveInt
from data_juicer.config import init_configs
from data_juicer.core.data import Dataset
from data_juicer.format.load import load_formatter
from data_juicer.format.mixture_formatter import MixtureFormatter
from data_juicer.ops import OPERATORS, load_ops
from data_juicer.ops.op_fusion import fuse_operators
from data_juicer.utils import cache_utils
from data_juicer.utils.ckpt_utils import CheckpointManager
from ..ops.selector.frequency_specified_field_selector import \
FrequencySpecifiedFieldSelector
from ..ops.selector.topk_specified_field_selector import \
TopkSpecifiedFieldSelector
from .adapter import Adapter
from .exporter import Exporter
from .tracer import Tracer
class Executor:
"""
This Executor class is used to process a specific dataset.
It will load the dataset and unify the format, then apply all the
ops in the config file in order and generate a processed dataset.
"""
def __init__(self, cfg: Optional[Namespace] = None):
"""
Initialization method.
:param cfg: optional jsonargparse Namespace.
"""
self.cfg = init_configs() if cfg is None else cfg
self.work_dir = self.cfg.work_dir
self.tracer = None
self.ckpt_manager = None
self.adapter = Adapter(self.cfg)
        # only enable cache compression when using cache
if self.cfg.use_cache:
logger.info(f'Using cache compression method: '
f'[{self.cfg.cache_compress}]')
cache_utils.CACHE_COMPRESS = self.cfg.cache_compress
# setup formatter
logger.info('Setting up data formatter...')
self.formatter = load_formatter(
dataset_path=self.cfg.dataset_path,
generated_dataset_config=self.cfg.generated_dataset_config,
text_keys=self.cfg.text_keys,
suffixes=self.cfg.suffixes,
add_suffix=self.cfg.add_suffix)
# whether to use checkpoint mechanism. If it's true, Executor will
# check if there are existing checkpoints first and try to load the
# checkpoints. If the checkpoints are loaded successfully, ops that
# have been processed will be skipped.
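        # A sketch of the related config keys (values are illustrative):
        #   use_checkpoint: true
        #   work_dir: './outputs/demo-process'
        #   np: 4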
if self.cfg.use_checkpoint:
logger.info('Preparing checkpoint manager...')
self.ckpt_dir = os.path.join(self.work_dir, 'ckpt')
self.ckpt_manager = CheckpointManager(self.ckpt_dir,
self.cfg.process,
self.cfg.np)
if self.ckpt_manager.ckpt_available:
                logger.info('Found existing dataset checkpoint.')
self.cfg.process = self.ckpt_manager.get_left_process_list()
# prepare exporter and check export path suffix
logger.info('Preparing exporter...')
self.exporter = Exporter(
self.cfg.export_path,
self.cfg.export_shard_size,
self.cfg.export_in_parallel,
self.cfg.np,
keep_stats_in_res_ds=self.cfg.keep_stats_in_res_ds,
keep_hashes_in_res_ds=self.cfg.keep_hashes_in_res_ds)
# setup tracer
self.open_tracer = self.cfg.open_tracer
if self.open_tracer:
logger.info('Preparing tracer...')
self.tracer = Tracer(self.work_dir, show_num=self.cfg.trace_num)
self.op_list_to_trace = self.cfg.op_list_to_trace
if len(self.cfg.op_list_to_trace) == 0:
                logger.info('No ops specified; tracing all ops.')
self.op_list_to_trace = set(OPERATORS.modules.keys())
def sample_data(self,
                    dataset_to_sample: Optional[Dataset] = None,
load_data_np=None,
sample_ratio: float = 1.0,
sample_algo: str = 'uniform',
**kwargs):
"""
Sample a subset from the given dataset.
:param dataset_to_sample: Dataset to sample from. If None, will use
the formatter linked by the executor. Default is None.
:param load_data_np: number of workers when loading the dataset.
:param sample_ratio: The ratio of the sample size to the original
dataset size. Default is 1.0 (no sampling).
:param sample_algo: Sampling algorithm to use. Options are "uniform",
"frequency_specified_field_selector", or
"topk_specified_field_selector".
Default is "uniform".
:return: A sampled Dataset.
"""
# Determine the dataset to sample from
if dataset_to_sample is not None:
dataset = dataset_to_sample
elif self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available:
logger.info('Loading dataset from checkpoint...')
dataset = self.ckpt_manager.load_ckpt()
elif hasattr(self, 'formatter'):
logger.info('Loading dataset from data formatter...')
if load_data_np is None:
load_data_np = self.cfg.np
dataset = self.formatter.load_dataset(load_data_np, self.cfg)
else:
raise ValueError('No dataset available to sample from.')
# Perform sampling based on the specified algorithm
if sample_algo == 'uniform':
return MixtureFormatter.random_sample(dataset, sample_ratio)
elif sample_algo == 'frequency_specified_field_selector':
dj_op = FrequencySpecifiedFieldSelector(**kwargs)
return dj_op.process(dataset)
elif sample_algo == 'topk_specified_field_selector':
dj_op = TopkSpecifiedFieldSelector(**kwargs)
return dj_op.process(dataset)
else:
raise ValueError(f'Unsupported sample_algo: {sample_algo}')
def run(self,
load_data_np: Optional[PositiveInt] = None,
skip_return=False):
"""
Running the dataset process pipeline.
:param load_data_np: number of workers when loading the dataset.
:param skip_return: skip return for API called.
:return: processed dataset.
"""
# 1. format data
if self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available:
logger.info('Loading dataset from checkpoint...')
dataset = self.ckpt_manager.load_ckpt()
else:
logger.info('Loading dataset from data formatter...')
if load_data_np is None:
load_data_np = self.cfg.np
dataset = self.formatter.load_dataset(load_data_np, self.cfg)
        # 2. extract processes and optimize their order
logger.info('Preparing process operators...')
ops = load_ops(self.cfg.process)
# OP fusion
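        # (roughly, fusible ops are those sharing the same intermediate
        #  variables, so fusing them avoids recomputing shared contexts)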
if self.cfg.op_fusion:
probe_res = None
if self.cfg.fusion_strategy == 'probe':
logger.info('Probe the OP speed for OP reordering...')
probe_res, _ = self.adapter.probe_small_batch(dataset, ops)
logger.info(f'Start OP fusion and reordering with strategy '
f'[{self.cfg.fusion_strategy}]...')
ops = fuse_operators(ops, probe_res)
# adaptive batch size
if self.cfg.adaptive_batch_size:
# calculate the adaptive batch size
bs_per_op = self.adapter.adapt_workloads(dataset, ops)
assert len(bs_per_op) == len(ops)
# update the adaptive batch size
logger.info(f'Adapt batch sizes for each OP to {bs_per_op}')
for i, op in enumerate(ops):
if op.is_batched_op():
op.batch_size = bs_per_op[i]
# 3. data process
        # - If the tracer is enabled, trace each op after it is applied
        # - If checkpointing is enabled, clean the cache files after each op
logger.info('Processing data...')
tstart = time()
dataset = dataset.process(
ops,
work_dir=self.work_dir,
exporter=self.exporter,
checkpointer=self.ckpt_manager,
tracer=self.tracer,
adapter=self.adapter,
open_monitor=self.cfg.open_monitor,
)
tend = time()
logger.info(f'All OPs are done in {tend - tstart:.3f}s.')
# 4. data export
logger.info('Exporting dataset to disk...')
self.exporter.export(dataset)
# compress the last dataset after exporting
if self.cfg.use_cache and self.cfg.cache_compress:
from data_juicer.utils.compress import compress
compress(dataset)
if not skip_return:
return dataset