Source code for data_juicer.ops.selector.frequency_specified_field_selector

import numbers
from typing import Optional

from pydantic import Field, PositiveInt
from typing_extensions import Annotated

from ..base_op import OPERATORS, Selector


@OPERATORS.register_module('frequency_specified_field_selector')
class FrequencySpecifiedFieldSelector(Selector):
    """Selector to select samples based on the sorted frequency of
    specified field."""

    def __init__(self,
                 field_key: str = '',
                 top_ratio: Optional[Annotated[float,
                                               Field(ge=0, le=1)]] = None,
                 topk: Optional[PositiveInt] = None,
                 reverse: bool = True,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param field_key: Selector based on the specified value
            corresponding to the target key. The target key
            corresponding to multi-level field information need to be
            separated by '.'.
        :param top_ratio: Ratio of selected top specified field value,
            samples will be selected if their specified field values are
            within this parameter. The ratio is applied to the number of
            distinct field values. When both topk and top_ratio are set,
            the value corresponding to the smaller number of field
            values will be applied.
        :param topk: Number of selected top specified field value,
            samples will be selected if their specified field values are
            within this parameter. When both topk and top_ratio are set,
            the value corresponding to the smaller number of field
            values will be applied.
        :param reverse: Determine the sorting rule, if reverse=True,
            then sort in descending order.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self.field_key = field_key
        self.top_ratio = top_ratio
        self.topk = topk
        self.reverse = reverse

    def process(self, dataset):
        """
        Keep the samples whose field value is among the most (or least)
        frequent distinct values of ``field_key``.

        Samples are grouped by the value found at ``field_key``; the
        groups are sorted by size (descending when ``reverse`` is true,
        ascending otherwise, ties keeping first-occurrence order) and
        the samples of the leading groups are selected.

        :param dataset: dataset to select from. It must support
            ``len()``, column access by name, ``.features`` and
            ``.select(indices)``.
        :return: ``dataset`` itself when it holds at most one sample,
            when ``field_key`` is empty, or when neither ``top_ratio``
            nor ``topk`` is set; otherwise the result of
            ``dataset.select`` on the chosen sample indices.
        :raises AssertionError: if a key of ``field_key`` is missing or
            a field value is not a string, a number or None.
        """
        if len(dataset) <= 1 or not self.field_key:
            return dataset

        field_keys = self.field_key.split('.')
        assert field_keys[0] in dataset.features.keys(
        ), "'{}' not in {}".format(field_keys[0], dataset.features.keys())

        # Map every distinct field value to the indices of the samples
        # carrying it. Dict insertion order keeps groups in
        # first-occurrence order, which the stable sort below relies on.
        field_value_dict = {}
        for i, item in enumerate(dataset[field_keys[0]]):
            field_value = item
            # Walk down the nested levels named after the first key.
            for key in field_keys[1:]:
                assert key in field_value.keys(), "'{}' not in {}".format(
                    key, field_value.keys())
                field_value = field_value[key]
            assert field_value is None or isinstance(
                field_value, (str, numbers.Number)
            ), 'The {} item is not String, Numbers or NoneType'.format(i)
            field_value_dict.setdefault(field_value, []).append(i)

        # Resolve how many distinct values to keep.
        # NOTE: the checks are truthiness-based, so a top_ratio of 0 is
        # treated the same as an unset top_ratio.
        if not self.top_ratio:
            if not self.topk:
                return dataset
            select_num = self.topk
        else:
            select_num = self.top_ratio * len(field_value_dict)
            if self.topk and self.topk < select_num:
                select_num = self.topk

        # int() truncates the fractional count produced by top_ratio.
        top_groups = sorted(field_value_dict.values(),
                            key=len,
                            reverse=self.reverse)[:int(select_num)]

        # Flatten in a single pass; ``sum(groups, [])`` would re-copy the
        # accumulated list for every group (quadratic time).
        select_index = [idx for group in top_groups for idx in group]
        return dataset.select(select_index)