Source code for data_juicer.ops.selector.topk_specified_field_selector

import heapq
from typing import Optional

from pydantic import Field, PositiveInt
from typing_extensions import Annotated

from data_juicer.utils.common_utils import stats_to_number

from ..base_op import OPERATORS, Selector


@OPERATORS.register_module('topk_specified_field_selector')
class TopkSpecifiedFieldSelector(Selector):
    """Selector to select top samples based on the sorted specified field
    value."""

    def __init__(self,
                 field_key: str = '',
                 top_ratio: Optional[Annotated[float,
                                               Field(ge=0, le=1)]] = None,
                 topk: Optional[PositiveInt] = None,
                 reverse: bool = True,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param field_key: Selector based on the specified value
            corresponding to the target key. The target key corresponding
            to multi-level field information need to be separated by '.'.
            An empty key disables selection (the dataset is returned
            unchanged).
        :param top_ratio: Ratio of selected top samples, samples will be
            selected if their specified field values are within this
            parameter. The resulting count ``top_ratio * len(dataset)``
            is truncated to an int. ``None`` means "not set"; ``0`` is a
            valid ratio and selects no samples. When both topk and
            top_ratio are set, the value corresponding to the smaller
            number of samples will be applied.
        :param topk: Number of selected top sample, samples will be
            selected if their specified field values are within this
            parameter. ``None`` means "not set". When both topk and
            top_ratio are set, the value corresponding to the smaller
            number of samples will be applied.
        :param reverse: Determine the sorting rule, if reverse=True, then
            sort in descending order (largest values are selected).
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self.field_key = field_key
        self.top_ratio = top_ratio
        self.topk = topk
        self.reverse = reverse

    def process(self, dataset):
        """
        Select the top samples of ``dataset`` ranked by ``field_key``.

        :param dataset: dataset to select from. It must support ``len()``,
            column access by name, ``.features`` and ``.select(indices)``.
        :return: ``dataset.select(...)`` holding the selected samples, in
            ranked order. ``dataset`` itself is returned unchanged when it
            has at most one sample, when ``field_key`` is empty, or when
            neither ``top_ratio`` nor ``topk`` is set.
        """
        if len(dataset) <= 1 or not self.field_key:
            return dataset

        select_num = self._get_select_num(len(dataset))
        if select_num is None:
            return dataset

        field_value_list = self._get_field_values(dataset)

        pick = heapq.nlargest if self.reverse else heapq.nsmallest
        select_index = pick(select_num,
                            range(len(dataset)),
                            key=field_value_list.__getitem__)
        return dataset.select(select_index)

    def _get_select_num(self, num_samples):
        """Return how many samples to keep, or None when no limit is set."""
        # Compare against None explicitly: top_ratio == 0 is a valid ratio
        # (keep nothing) and must not be mistaken for "unset".
        if self.top_ratio is None:
            if self.topk is None:
                return None
            return int(self.topk)

        select_num = self.top_ratio * num_samples
        if self.topk is not None and self.topk < select_num:
            select_num = self.topk
        return int(select_num)

    def _get_field_values(self, dataset):
        """Return one comparable sort key per sample for ``field_key``."""
        field_keys = self.field_key.split('.')
        assert field_keys[0] in dataset.features.keys(
        ), "'{}' not in {}".format(field_keys[0], dataset.features.keys())

        field_value_list = []
        for item in dataset[field_keys[0]]:
            field_value = item
            # Walk the remaining dotted path; empty for a single-level key.
            for key in field_keys[1:]:
                assert key in field_value.keys(), "'{}' not in {}".format(
                    key, field_value.keys())
                field_value = field_value[key]
            # Normalise top-level and nested values the same way, so the
            # ranking does not depend on whether the key contains a '.'.
            # NOTE(review): stats_to_number presumably maps strings/lists/
            # missing values to a float, with `reverse` choosing where
            # unparsable values sort — confirm in common_utils.
            field_value_list.append(stats_to_number(field_value,
                                                    self.reverse))
        return field_value_list