Source code for data_juicer.analysis.draw

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def draw_heatmap(data,
                 xlabels,
                 ylables=None,
                 figsize=None,
                 triangle=False,
                 *,
                 ylabels=None):
    """
    Draw a heatmap of the input data with the given axis labels.

    :param data: 2-D input data, e.g. a `list`, `tuple`, `numpy array` or
        'torch tensor' (anything accepted by `np.ones_like` and
        `seaborn.heatmap`).
    :param xlabels: x axis labels.
    :param ylables: y axis labels. The misspelled name is kept for backward
        compatibility. If None (and ``ylabels`` is not given), ``xlabels``
        is used.
    :param figsize: figure size in inches. Defaults to ``(20.0, 15.0)`` when
        not given (or falsy).
    :param triangle: if True, mask the upper triangle *including* the
        diagonal, so only the strict lower triangle is displayed.
    :param ylabels: keyword-only, correctly spelled alias for ``ylables``.
        It takes precedence when both are given.
    :return: the matplotlib figure containing the heatmap.
    """
    # Prefer the correctly spelled keyword; fall back to the legacy name.
    if ylabels is None:
        ylabels = ylables
    # Honour the documented contract "if None, use xlabels": seaborn's
    # ``yticklabels`` accepts 'auto', bool, list-like or int, not None.
    if ylabels is None:
        ylabels = xlabels

    figsize = figsize if figsize else (8 * 2.5, 6 * 2.5)
    # Keep the figure returned here instead of re-fetching it via gcf().
    fig, ax = plt.subplots(figsize=figsize)

    mask = None
    if triangle:
        # np.triu keeps the main diagonal, so the diagonal and everything
        # above it are masked (hidden) by seaborn.
        mask = np.triu(np.ones_like(data))

    # Show tick labels on all four sides of the axes.
    ax.tick_params(
        right=True,
        top=True,
        labelright=True,
        labeltop=True,
    )
    sns.heatmap(data,
                ax=ax,
                cmap='Oranges',
                annot=True,
                mask=mask,
                linewidths=.05,
                square=True,
                xticklabels=xlabels,
                yticklabels=ylabels,
                annot_kws={'size': 8})
    fig.subplots_adjust(left=.1, right=0.95, bottom=0.22, top=0.95)
    plt.show()
    return fig