Source code for data_juicer.analysis.column_wise_analysis

import math
import os

import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

from data_juicer.utils.constant import Fields

from .overall_analysis import OverallAnalysis


def get_row_col(total_num, factor=2):
    """
    Given the total number of stats figures, get the "best" number of rows and
    columns. This function is needed when we need to store all stats figures
    into one image.

    :param total_num: Total number of stats figures
    :param factor: Number of sub-figure types in each figure. In
        default, it's 2, which means there are histogram and box plot
        for each stat figure
    :return: "best" number of rows and columns, and the grid list
    """
    n = total_num * factor  # actual number of figures
    now_col = factor  # search from the minimum number of columns
    now_row = total_num
    for col in range(factor, n + 1, factor):
        row = n * 1.0 / col
        if row != int(row):  # skip non-integer results
            continue
        if col > row:
            # object: minimum the difference between number of columns and rows
            if abs(col - row) > abs(now_col - now_row):
                break
            else:
                now_row = row
                now_col = col
                break
        now_row = row
        now_col = col

    # different sub-figures of the same stats should be in the same row
    now_col = now_col // factor

    # get grid indexes
    grids = []
    for i in range(total_num):
        grids.append((i // now_col, i % now_col))

    return int(now_row), int(now_col), grids


[docs]class ColumnWiseAnalysis: """Apply analysis on each column of stats respectively."""
[docs] def __init__(self, dataset, output_path, overall_result=None, save_stats_in_one_file=True): """ Initialization method :param dataset: the dataset to be analyzed :param output_path: path to store the analysis results :param overall_result: optional precomputed overall stats result :param save_stats_in_one_file: whether save all analysis figures of all stats into one image file """ self.stats = pd.DataFrame(dataset[Fields.stats]) self.output_path = output_path if not os.path.exists(self.output_path): os.makedirs(self.output_path) # if no overall description provided, analyze it from scratch if overall_result is None: oa = OverallAnalysis(dataset, output_path) overall_result = oa.analyze() self.overall_result = overall_result self.save_stats_in_one_file = save_stats_in_one_file
[docs] def analyze(self, show_percentiles=False, show=False, skip_export=False): """ Apply analysis and draw the analysis figure for stats. :param show_percentiles: whether to show the percentile line in each sub-figure. If it's true, there will be several red lines to indicate the quantiles of the stats distributions :param show: whether to show in a single window after drawing :param skip_export: whether save the results into disk :return: """ # number of sub-figures for each stat. There are histogram and box plot # for now, so it's 2. num_subcol = 2 # Default width and height unit for each sub-figure width_unit = 4 height_unit = 6 columns = self.stats.columns num = len(columns) # get the recommended "best" number of columns and rows rec_row, rec_col, grid_indexes = get_row_col(num, num_subcol) if self.save_stats_in_one_file: # if save_stats_in_one_file is opened, use recommended "best" # number of columns and rows to initialize the image panel. rec_width = rec_col * num_subcol * width_unit rec_height = rec_row * height_unit fig = plt.figure(figsize=(rec_width, rec_height), layout='constrained') subfigs = fig.subfigures(rec_row, rec_col, wspace=0.01) for i, column_name in enumerate(tqdm(columns.to_list(), desc='Column')): data = self.stats[column_name] # explode data to flatten inner list data = data.explode().infer_objects() grid = grid_indexes[i] if self.save_stats_in_one_file: if rec_col == 1: grid = grid[0] elif rec_row == 1: grid = grid[1] if rec_col == 1 and rec_row == 1: subfig = subfigs else: subfig = subfigs[grid] subfig.set_facecolor('0.85') # numeric or string via nan. Apply different plot method for them. if pd.isna(self.overall_result[column_name].get('top')): # numeric or numeric list -- draw histogram and box plot for # this stat percentiles = self.overall_result[column_name] \ if show_percentiles else None # get axes for each subplot if self.save_stats_in_one_file: axes = subfig.subplots(1, num_subcol) else: axes = [None] * num_subcol # draw histogram self.draw_hist(axes[0], data, os.path.join(self.output_path, f'{column_name}-hist.png'), percentiles=percentiles) # draw box self.draw_box(axes[1], data, os.path.join(self.output_path, f'{column_name}-box.png'), percentiles=percentiles) else: # object (string) or string list -- only draw histogram for # this stat if self.save_stats_in_one_file: axes = subfig.subplots(1, 1) else: axes = None if not skip_export: self.draw_hist( axes, data, os.path.join(self.output_path, f'{column_name}-hist.png')) # add a title to the figure of this stat if self.save_stats_in_one_file: subfig.suptitle(f'{data.name}', fontsize='x-large', fontweight='bold') if self.save_stats_in_one_file: fig = plt.gcf() if not skip_export: fig.savefig(os.path.join(self.output_path, 'all-stats.png')) if show: plt.show() else: pass # TODO: (fixme) the saved png sometime are blank plt.clf()
[docs] def draw_hist(self, ax, data, save_path, percentiles=None, show=False): """ Draw the histogram for the data. :param ax: the axes to draw :param data: data to draw :param save_path: the path to save the histogram figure :param percentiles: the overall analysis result of the data including percentile information :param show: whether to show in a single window after drawing :return: """ # recommended number of bins data_num = len(data) if data_num >= 100: rec_bins = int(math.sqrt(len(data))) else: rec_bins = None # if ax is None, using plot method in pandas if ax is None: ax = data.hist(bins=rec_bins, figsize=(20, 16)) else: ax.hist(data, bins=rec_bins) # set axes ax.set_xlabel(data.name) ax.set_ylabel('Count') # draw percentile lines if it's not None if percentiles is not None: ymin, ymax = ax.get_ylim() for percentile in percentiles.keys(): # skip other information if percentile in {'count', 'unique', 'top', 'freq', 'std'}: continue value = percentiles[percentile] ax.vlines(x=value, ymin=ymin, ymax=ymax, colors='r') ax.text(x=value, y=ymax, s=percentile, rotation=30, color='r') ax.text(x=value, y=ymax * 0.97, s=str(round(value, 3)), rotation=30, color='r') if not self.save_stats_in_one_file: # save into file plt.savefig(save_path) if show: plt.show() else: # if no showing, we need to clear this axes to avoid # accumulated overlapped figures in different draw_xxx function # calling ax.clear() else: # add a little rotation on labels of x axis to avoid overlapping ax.tick_params(axis='x', rotation=25)
[docs] def draw_box(self, ax, data, save_path, percentiles=None, show=False): """ Draw the box plot for the data. :param ax: the axes to draw :param data: data to draw :param save_path: the path to save the box figure :param percentiles: the overall analysis result of the data including percentile information :param show: whether to show in a single window after drawing :return: """ # if ax is None, using plot method in pandas if ax is None: ax = data.plot.box(figsize=(20, 16)) else: ax.boxplot(data) # set axes ax.set_ylabel(data.name) # draw percentile lines if it's not None if percentiles is not None: xmin, xmax = ax.get_xlim() for percentile in percentiles.keys(): # skip other information if percentile in {'count', 'unique', 'top', 'freq', 'std'}: continue value = percentiles[percentile] ax.hlines(y=value, xmin=xmin, xmax=xmax, colors='r') ax.text(y=value, x=xmin + (xmax - xmin) * 0.6, s=f'{percentile}: {round(value, 3)}', color='r') if not self.save_stats_in_one_file: # save into file plt.savefig(save_path) if show: plt.show() else: # if no showing, we need to clear this axes to avoid # accumulated overlapped figures in different draw_xxx function # calling ax.clear()