07.a stats.2dbin and mpl.heatmap

This script demonstrates how to aggregate 2D data using scipy.stats.binned_statistic_2d and then visualize the result as a detailed, annotated heatmap.

It provides a comprehensive workflow that includes:

  • Binning Data: It groups scattered 2D points into a grid and computes statistics like count, median, and mean for the values within each bin.

  • Custom Heatmap Function: It uses custom helper functions to build a polished heatmap from scratch using matplotlib, complete with annotations for each cell.

The four panels show r1 (count per bin), r2 (count on an automatic 4 x 7 grid), r3 (median per bin) and r4 (mean per bin).

Out:

        Unnamed: 0  sample  timestep                       features  feature_values  shap_values
0                0       0         0                   Ward Lactate        0.000000     0.000652
1                1       0         0                   Ward Glucose        0.000000    -0.000596
2                2       0         0                       Ward sO2        0.000000     0.000231
3                3       0         0  White blood cell count, blood        0.000000     0.000582
4                4       0         0                      Platelets        0.000000    -0.001705
...            ...     ...       ...                            ...             ...          ...
251995      251995     999         6                  Procalcitonin        0.000000     0.000027
251996      251996     999         6                       Ferritin        0.000000    -0.001375
251997      251997     999         6                        D-Dimer        0.000000     0.000045
251998      251998     999         6                            sex       -1.000000    -0.002359
251999      251999     999         6                            age        0.169952     0.000237

[252000 rows x 6 columns]
C:\Users\kelda\Desktop\repositories\virtualenvs\venv-py311-psc\Lib\site-packages\matplotlib\colors.py:2242: UserWarning:

Warning: converting a masked element to nan.

C:\Users\kelda\Desktop\repositories\virtualenvs\venv-py311-psc\Lib\site-packages\matplotlib\colors.py:2249: UserWarning:

Warning: converting a masked element to nan.

C:\Users\kelda\AppData\Local\Programs\Python\Python311\Lib\string.py:264: FutureWarning:

Format strings passed to MaskedConstant are ignored, but in future may error or produce different behavior

C:\Users\kelda\Desktop\repositories\virtualenvs\venv-py311-psc\Lib\site-packages\matplotlib\colors.py:2242: UserWarning:

Warning: converting a masked element to nan.

C:\Users\kelda\Desktop\repositories\virtualenvs\venv-py311-psc\Lib\site-packages\matplotlib\colors.py:2249: UserWarning:

Warning: converting a masked element to nan.

C:\Users\kelda\AppData\Local\Programs\Python\Python311\Lib\string.py:264: FutureWarning:

Format strings passed to MaskedConstant are ignored, but in future may error or produce different behavior

C:\Users\kelda\Desktop\repositories\github\python-spare-code\main\examples\matplotlib\plot_main07_a_2dbin_stat.py:256: UserWarning:

FigureCanvasAgg is non-interactive, and thus cannot be shown

 20 import matplotlib
 21 import numpy as np
 22 import pandas as pd
 23 import matplotlib as mpl
 24 import matplotlib.pyplot as plt
 25
 26 from scipy import stats
 27
 28 # See https://matplotlib.org/devdocs/users/explain/customizing.html
 29 mpl.rcParams['font.size'] = 8
 30 mpl.rcParams['axes.titlesize'] = 8
 31 mpl.rcParams['axes.labelsize'] = 8
 32 mpl.rcParams['xtick.labelsize'] = 8
 33 mpl.rcParams['ytick.labelsize'] = 8
 34
 35 def heatmap(data, row_labels, col_labels, ax=None,
 36             cbar_kw=None, cbarlabel="", **kwargs):
 37     """
 38     Create a heatmap from a numpy array and two lists of labels.
 39
 40     Parameters
 41     ----------
 42     data
 43         A 2D numpy array of shape (M, N).
 44     row_labels
 45         A list or array of length M with the labels for the rows.
 46     col_labels
 47         A list or array of length N with the labels for the columns.
 48     ax
 49         A `matplotlib.axes.Axes` instance to which the heatmap is plotted.  If
 50         not provided, use current axes or create a new one.  Optional.
 51     cbar_kw
 52         A dictionary with arguments to `matplotlib.Figure.colorbar`.  Optional.
 53     cbarlabel
 54         The label for the colorbar.  Optional.
 55     **kwargs
 56         All other arguments are forwarded to `imshow`.
 57     """
 58
 59     if ax is None:
 60         ax = plt.gca()
 61
 62     if cbar_kw is None:
 63         cbar_kw = {}
 64
 65     # Plot the heatmap
 66     im = ax.imshow(data, **kwargs)
 67
 68     # Create colorbar
 69     cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
 70     cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
 71
 72     # Show all ticks and label them with the respective list entries.
 73     ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
 74     ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
 75
 76     # Let the horizontal axes labeling appear on top.
 77     ax.tick_params(top=True, bottom=False,
 78                    labeltop=True, labelbottom=False)
 79
 80     # Rotate the tick labels and set their alignment.
 81     plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
 82              rotation_mode="anchor")
 83
 84     # Turn spines off and create white grid.
 85     ax.spines[:].set_visible(False)
 86
 87     ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
 88     ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
 89     ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
 90     ax.tick_params(which="minor", bottom=False, left=False)
 91
 92     return im, cbar
 93
 94
 95 def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
 96                      textcolors=("black", "white"),
 97                      threshold=None, **textkw):
 98     """
 99     A function to annotate a heatmap.
100
101     Parameters
102     ----------
103     im
104         The AxesImage to be labeled.
105     data
106         Data used to annotate.  If None, the image's data is used.  Optional.
107     valfmt
108         The format of the annotations inside the heatmap.  This should either
109         use the string format method, e.g. "$ {x:.2f}", or be a
110         `matplotlib.ticker.Formatter`.  Optional.
111     textcolors
112         A pair of colors.  The first is used for values below a threshold,
113         the second for those above.  Optional.
114     threshold
115         Value in data units according to which the colors from textcolors are
116         applied.  If None (the default) uses the middle of the colormap as
117         separation.  Optional.
118     **kwargs
119         All other arguments are forwarded to each call to `text` used to create
120         the text labels.
121     """
122
123     if not isinstance(data, (list, np.ndarray)):
124         data = im.get_array()
125
126     # Normalize the threshold to the images color range.
127     if threshold is not None:
128         threshold = im.norm(threshold)
129     else:
130         threshold = im.norm(data.max())/2.
131
132     # Set default alignment to center, but allow it to be
133     # overwritten by textkw.
134     kw = dict(horizontalalignment="center",
135               verticalalignment="center")
136     kw.update(textkw)
137
138     # Get the formatter in case a string is supplied
139     if isinstance(valfmt, str):
140         valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)
141
142     # Loop over the data and create a `Text` for each "pixel".
143     # Change the text's color depending on the data.
144     texts = []
145     for i in range(data.shape[0]):
146         for j in range(data.shape[1]):
147             kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
148             text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
149             texts.append(text)
150
151     return texts
152
153
def plot_binned_statistic(r, ax, title=None, astype=None,
                          cmap="coolwarm", cbarlabel="value [unit]",
                          **kwargs):
    """Plot a 2D binned statistic as an annotated heatmap.

    Heatmap rows correspond to the x bins and columns to the y bins, because
    ``r.statistic`` has shape ``(nx, ny)``. Each row/column is labelled with
    its bin centre rounded to 2 decimals.

    Parameters
    ----------
    r: the result of `scipy.stats.binned_statistic_2d`
    ax: the axes to plot into
    title: optional axes title
    astype: unused; kept for backward compatibility
    cmap: colormap passed to `imshow` (default "coolwarm")
    cbarlabel: colorbar label (default "value [unit]")
    **kwargs: forwarded to `annotate_heatmap` (e.g. ``valfmt="{x:g}"``)

    Returns
    -------
    (im, cbar, texts): the image, its colorbar and the annotation texts.
    """
    # Bin centres are the midpoints of consecutive bin edges.
    x_center = (r.x_edge[:-1] + r.x_edge[1:]) / 2
    y_center = (r.y_edge[:-1] + r.y_edge[1:]) / 2

    # Plot heatmap (matplotlib sample, use seaborn instead)
    im, cbar = heatmap(r.statistic,
        np.around(x_center, 2), np.around(y_center, 2), ax=ax,
        cmap=cmap, cbarlabel=cbarlabel)
    texts = annotate_heatmap(im, **kwargs)

    # Configure: square cells and optional title.
    ax.set_aspect('equal', 'box')
    if title is not None:
        ax.set_title(title)

    return im, cbar, texts
192
def data_manual():
    """Return a small, fixed (x, y, z) sample for quick experiments.

    Returns
    -------
    tuple of numpy.ndarray
        Three integer arrays of length 9: the x and y coordinates of each
        point and the value z attached to it.
    """
    columns = (
        [1, 1, 1, 1, 2, 2, 2, 3, 4],  # x
        [1, 1, 2, 2, 3, 4, 5, 6, 7],  # y
        [1, 9, 9, 1, 2, 2, 2, 3, 4],  # z
    )
    x, y, z = (np.array(column) for column in columns)
    return x, y, z
200
def data_shap(path='../../datasets/shap/shap.csv'):
    """Load the SHAP example dataset and return the columns to bin.

    Parameters
    ----------
    path: str or path-like
        Location of the CSV file. The default is resolved relative to the
        current working directory (the original gallery layout).

    Returns
    -------
    tuple of pandas.Series
        ``(timestep, shap_values, feature_values)``, used as ``x``, ``y``
        and ``z`` (the values aggregated per bin) respectively.
    """
    data = pd.read_csv(path)
    print(data)
    return data.timestep, data.shap_values, data.feature_values
206
207
208
209
# Load data
# x, y, z = data_manual()
x, y, z = data_shap()

# Bin edges for x: one bin per integer value k, with edges at k - 0.5 and
# k + 0.5, covering every value from x.min() to x.max(). Starting the edges
# at 0.5 would silently drop x == 0 (timestep 0 of the SHAP data).
binx = np.arange(x.min(), x.max() + 2) - 0.5

# Bin edges for y: 10 evenly spaced edges (9 bins) over the range of y.
# For integer-valued y (e.g. data_manual) the per-integer scheme also works:
# biny = np.arange(y.min(), y.max() + 2) - 0.5
biny = np.linspace(y.min(), y.max(), 10)

# Manual
# binx = np.arange(5) + 0.5
# biny = np.arange(8) + 0.5

# Compute binned statistic (count of points per bin)
r1 = stats.binned_statistic_2d(x=x, y=y, values=None,
    statistic='count', bins=[binx, biny],
    expand_binnumbers=True)

# Compute binned statistic (count) on an automatic 4 x 7 grid spanning
# the range of the data
r2 = stats.binned_statistic_2d(x=x, y=y, values=z,
    statistic='count', bins=[4, 7],
    expand_binnumbers=False)

# Compute binned statistic (median of z per bin; empty bins are NaN)
r3 = stats.binned_statistic_2d(x=x, y=y, values=z,
    statistic='median', bins=[binx, biny],
    expand_binnumbers=False)

# Compute binned statistic (mean of z per bin; empty bins are NaN)
r4 = stats.binned_statistic_2d(x=x, y=y, values=z,
    statistic='mean', bins=[binx, biny],
    expand_binnumbers=False)


# Plot. The axes are deliberately not shared: each heatmap sets its own
# tick positions/labels (and r2 uses a different grid), and shared axes
# would make every panel show the ticks of the last one drawn.
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(14, 7))
plot_binned_statistic(r1, axs[0, 0], title='r1 (count)', valfmt="{x:g}")
plot_binned_statistic(r2, axs[0, 1], title='r2 (count)', valfmt="{x:g}")
plot_binned_statistic(r3, axs[1, 0], title='r3 (median)')
plot_binned_statistic(r4, axs[1, 1], title='r4 (mean)')

# Display
plt.tight_layout()
plt.show()

Total running time of the script: ( 0 minutes 1.031 seconds)

Gallery generated by Sphinx-Gallery