Source code for data_juicer.analysis.column_wise_analysis

import math
import os

import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from wordcloud import WordCloud

from data_juicer.utils.constant import DEFAULT_PREFIX, Fields

from .overall_analysis import OverallAnalysis


def get_row_col(total_num, factor=2):
    """
    Choose a near-square grid layout for a number of stats figures.

    This is needed when all stats figures have to be packed into one image.

    :param total_num: Total number of stats figures
    :param factor: Number of sub-figure types in each figure. In default,
        it's 2, which means there are histogram and box plot for each stat
        figure
    :return: number of rows, number of columns (both counted in stats
        figures, not in sub-figures), and a list holding the
        ``(row, column)`` grid index of every stats figure
    """
    sub_total = total_num * factor  # number of sub-figures to place
    # start from the narrowest layout: one stats figure per row
    best_rows, best_cols = total_num, factor

    # only column counts that are multiples of ``factor`` keep the
    # sub-figures of one stat together
    for cols in range(factor, sub_total + 1, factor):
        rows, leftover = divmod(sub_total, cols)
        if leftover:
            # this column count does not fill a rectangle exactly
            continue
        if cols <= rows:
            # still taller than wide: remember it and keep widening
            best_rows, best_cols = rows, cols
            continue
        # first layout that is wider than tall: take it unless it is
        # strictly less square than the best one so far, then stop
        if abs(cols - rows) <= abs(best_cols - best_rows):
            best_rows, best_cols = rows, cols
        break

    # different sub-figures of the same stats should be in the same row,
    # so convert the column count from sub-figures to stats figures
    figs_per_row = best_cols // factor

    # row-major grid index of each stats figure
    cells = [divmod(idx, figs_per_row) for idx in range(total_num)]

    return int(best_rows), int(figs_per_row), cells
class ColumnWiseAnalysis:
    """Apply analysis on each column of stats respectively."""

    # Keys of an overall-analysis entry that are not positions on the value
    # axis of a stat, so no reference line is drawn for them.
    _NON_POSITION_KEYS = frozenset({'count', 'unique', 'top', 'freq', 'std'})

    def __init__(self,
                 dataset,
                 output_path,
                 overall_result=None,
                 save_stats_in_one_file=True):
        """
        Initialization method

        :param dataset: the dataset to be analyzed
        :param output_path: path to store the analysis results
        :param overall_result: optional precomputed overall stats result
        :param save_stats_in_one_file: whether save all analysis figures of
            all stats into one image file
        """
        self.stats = pd.DataFrame(dataset[Fields.stats])
        self.meta = pd.DataFrame(dataset[Fields.meta])
        # remove non-tag columns
        non_tag_columns = [
            col_name for col_name in self.meta.columns
            if not col_name.startswith(DEFAULT_PREFIX)
        ]
        self.meta = self.meta.drop(columns=non_tag_columns)

        self.output_path = output_path
        # exist_ok avoids the race between checking for the directory and
        # creating it
        os.makedirs(self.output_path, exist_ok=True)

        # if no overall description provided, analyze it from scratch
        if overall_result is None:
            oa = OverallAnalysis(dataset, output_path)
            overall_result = oa.analyze()
        self.overall_result = overall_result

        self.save_stats_in_one_file = save_stats_in_one_file

    @staticmethod
    def _select_subfig(subfigs, grid, rec_row, rec_col):
        """Pick the sub-figure at ``grid`` from a squeezed sub-figure grid."""
        if rec_row == 1 and rec_col == 1:
            # a 1x1 grid is a single sub-figure, not a container
            return subfigs
        if rec_col == 1:
            # single column: 1-D container indexed by row
            return subfigs[grid[0]]
        if rec_row == 1:
            # single row: 1-D container indexed by column
            return subfigs[grid[1]]
        return subfigs[grid]

    def analyze(self, show_percentiles=False, show=False, skip_export=False):
        """
        Apply analysis and draw the analysis figure for stats.

        :param show_percentiles: whether to show the percentile line in
            each sub-figure. If it's true, there will be several red
            lines to indicate the quantiles of the stats distributions
        :param show: whether to show in a single window after drawing
        :param skip_export: whether to skip drawing and saving the results
            to disk
        :return:
        """
        # number of sub-figures for each stat. There are histogram and box
        # plot for now, so it's 2.
        num_subcol = 2

        # Default width and height unit for each sub-figure
        width_unit = 4
        height_unit = 6

        stats_and_meta = pd.concat([self.stats, self.meta], axis=1)
        all_columns = stats_and_meta.columns
        num = len(all_columns)
        if num == 0:
            # nothing to draw; also avoids asking for a 0-row figure grid
            return

        # get the recommended "best" number of columns and rows
        rec_row, rec_col, grid_indexes = get_row_col(num, num_subcol)

        fig = subfigs = None
        if self.save_stats_in_one_file:
            # if save_stats_in_one_file is opened, use recommended "best"
            # number of columns and rows to initialize the image panel.
            rec_width = rec_col * num_subcol * width_unit
            rec_height = rec_row * height_unit
            fig = plt.figure(figsize=(rec_width, rec_height),
                             layout='constrained')
            subfigs = fig.subfigures(rec_row, rec_col, wspace=0.01)

        for i, column_name in enumerate(
                tqdm(all_columns.to_list(), desc='Column')):
            data = stats_and_meta[column_name]
            # explode data to flatten inner list
            data = data.explode().infer_objects()

            subfig = None
            if self.save_stats_in_one_file:
                subfig = self._select_subfig(subfigs, grid_indexes[i],
                                             rec_row, rec_col)
                subfig.set_facecolor('0.85')
                # get axes for each subplot
                axes = subfig.subplots(1, num_subcol)
            else:
                # each draw_xxx call creates its own figure
                axes = [None] * num_subcol

            # numeric or string via nan. Apply different plot method for
            # them.
            column_overall = self.overall_result[column_name]
            is_numeric = pd.isna(column_overall.get('top'))

            if not skip_export:
                hist_path = os.path.join(self.output_path,
                                         f'{column_name}-hist.png')
                if is_numeric:
                    # numeric or numeric list -- draw histogram and box
                    # plot for this stat
                    percentiles = column_overall \
                        if show_percentiles else None
                    self.draw_hist(axes[0],
                                   data,
                                   hist_path,
                                   percentiles=percentiles)
                    self.draw_box(axes[1],
                                  data,
                                  os.path.join(self.output_path,
                                               f'{column_name}-box.png'),
                                  percentiles=percentiles)
                else:
                    # object (string) or string list -- draw histogram and
                    # word cloud for this stat
                    self.draw_hist(axes[0], data, hist_path)
                    self.draw_wordcloud(
                        axes[1], data,
                        os.path.join(self.output_path,
                                     f'{column_name}-wordcloud.png'))

            # add a title to the figure of this stat
            if self.save_stats_in_one_file:
                subfig.suptitle(f'{data.name}',
                                fontsize='x-large',
                                fontweight='bold')

        if self.save_stats_in_one_file:
            if not skip_export:
                # save the figure created above rather than whatever
                # figure happens to be current
                fig.savefig(os.path.join(self.output_path, 'all-stats.png'))
            if show:
                plt.show()
            else:
                # TODO: (fixme) the saved png sometime are blank
                plt.clf()

    def draw_hist(self, ax, data, save_path, percentiles=None, show=False):
        """
        Draw the histogram for the data.

        :param ax: the axes to draw. If it's None, a new axes is created
            with the plot method in pandas
        :param data: data to draw
        :param save_path: the path to save the histogram figure
        :param percentiles: the overall analysis result of the data
            including percentile information
        :param show: whether to show in a single window after drawing
        :return:
        """
        # recommended number of bins
        data_num = len(data)
        rec_bins = max(int(math.sqrt(data_num)), 10)

        # if ax is None, using plot method in pandas
        if ax is None:
            ax = data.hist(bins=rec_bins, figsize=(20, 16))
        else:
            ax.hist(data, bins=rec_bins)

        # set axes
        ax.set_xlabel(data.name)
        ax.set_ylabel('Count')

        # draw percentile lines if it's not None
        if percentiles is not None:
            ymin, ymax = ax.get_ylim()
            for percentile in percentiles.keys():
                # skip other information
                if percentile in self._NON_POSITION_KEYS:
                    continue
                value = percentiles[percentile]

                ax.vlines(x=value, ymin=ymin, ymax=ymax, colors='r')
                ax.text(x=value,
                        y=ymax,
                        s=percentile,
                        rotation=30,
                        color='r')
                ax.text(x=value,
                        y=ymax * 0.97,
                        s=str(round(value, 3)),
                        rotation=30,
                        color='r')

        if not self.save_stats_in_one_file:
            # save into file
            plt.savefig(save_path)

            if show:
                plt.show()
            else:
                # if no showing, we need to clear this axes to avoid
                # accumulated overlapped figures in different draw_xxx
                # function calling
                ax.clear()
        else:
            # add a little rotation on labels of x axis to avoid
            # overlapping
            ax.tick_params(axis='x', rotation=25)

    def draw_box(self, ax, data, save_path, percentiles=None, show=False):
        """
        Draw the box plot for the data.

        :param ax: the axes to draw. If it's None, a new axes is created
            with the plot method in pandas
        :param data: data to draw
        :param save_path: the path to save the box figure
        :param percentiles: the overall analysis result of the data
            including percentile information
        :param show: whether to show in a single window after drawing
        :return:
        """
        # if ax is None, using plot method in pandas
        if ax is None:
            ax = data.plot.box(figsize=(20, 16))
        else:
            ax.boxplot(data)

        # set axes
        ax.set_ylabel(data.name)

        # draw percentile lines if it's not None
        if percentiles is not None:
            xmin, xmax = ax.get_xlim()
            for percentile in percentiles.keys():
                # skip other information
                if percentile in self._NON_POSITION_KEYS:
                    continue
                value = percentiles[percentile]

                ax.hlines(y=value, xmin=xmin, xmax=xmax, colors='r')
                ax.text(y=value,
                        x=xmin + (xmax - xmin) * 0.6,
                        s=f'{percentile}: {round(value, 3)}',
                        color='r')

        if not self.save_stats_in_one_file:
            # save into file
            plt.savefig(save_path)

            if show:
                plt.show()
            else:
                # if no showing, we need to clear this axes to avoid
                # accumulated overlapped figures in different draw_xxx
                # function calling
                ax.clear()

    def draw_wordcloud(self, ax, data, save_path, show=False):
        """
        Draw the word cloud for the data.

        :param ax: the axes to draw. If it's None, a new figure with a
            single axes is created
        :param data: data to draw; each value is counted as one word
        :param save_path: the path to save the word cloud image
        :param show: whether to show in a single window after drawing
        :return:
        """
        # frequency of each distinct value
        word_nums = {}
        for w in data.tolist():
            word_nums[w] = word_nums.get(w, 0) + 1

        wc = WordCloud(width=400, height=320)
        wc.generate_from_frequencies(word_nums)

        if ax is None:
            # create a real axes so that the cloud is also drawn when no
            # axes is given; otherwise showing it would pop up a blank
            # figure
            ax = plt.figure(figsize=(20, 16)).add_subplot(111)
        ax.imshow(wc, interpolation='bilinear')
        ax.axis('off')

        if not self.save_stats_in_one_file:
            # save into file
            wc.to_file(save_path)

            if show:
                plt.show()
            else:
                # if no showing, we need to clear this axes to avoid
                # accumulated overlapped figures in different draw_xxx
                # function calling
                ax.clear()