07.a stats.2dbin and mpl.heatmap

Compute 2D binned statistics with scipy.stats.binned_statistic_2d and display them as annotated matplotlib heatmaps.

r1 (count), r2 (count), r3 (median), r4 (mean)

Out:

        Unnamed: 0  sample  timestep                       features  feature_values  shap_values
0                0       0         0                   Ward Lactate        0.000000     0.000652
1                1       0         0                   Ward Glucose        0.000000    -0.000596
2                2       0         0                       Ward sO2        0.000000     0.000231
3                3       0         0  White blood cell count, blood        0.000000     0.000582
4                4       0         0                      Platelets        0.000000    -0.001705
...            ...     ...       ...                            ...             ...          ...
251995      251995     999         6                  Procalcitonin        0.000000     0.000027
251996      251996     999         6                       Ferritin        0.000000    -0.001375
251997      251997     999         6                        D-Dimer        0.000000     0.000045
251998      251998     999         6                            sex       -1.000000    -0.002359
251999      251999     999         6                            age        0.169952     0.000237

[252000 rows x 6 columns]
<__array_function__ internals>:180: UserWarning:

Warning: converting a masked element to nan.

/Users/cbit/Desktop/repositories/environments/venv-py3109-python-spare-code/lib/python3.10/site-packages/matplotlib/colors.py:1311: UserWarning:

Warning: converting a masked element to nan.

/Users/cbit/Desktop/repositories/environments/venv-py3109-python-spare-code/lib/python3.10/site-packages/matplotlib/ticker.py:374: FutureWarning:

Format strings passed to MaskedConstant are ignored, but in future may error or produce different behavior

  9 import matplotlib
 10 import numpy as np
 11 import pandas as pd
 12 import matplotlib as mpl
 13 import matplotlib.pyplot as plt
 14
 15 from scipy import stats
 16
 17 # See https://matplotlib.org/devdocs/users/explain/customizing.html
 18 mpl.rcParams['font.size'] = 8
 19 mpl.rcParams['axes.titlesize'] = 8
 20 mpl.rcParams['axes.labelsize'] = 8
 21 mpl.rcParams['xtick.labelsize'] = 8
 22 mpl.rcParams['ytick.labelsize'] = 8
 23
 24 def heatmap(data, row_labels, col_labels, ax=None,
 25             cbar_kw=None, cbarlabel="", **kwargs):
 26     """
 27     Create a heatmap from a numpy array and two lists of labels.
 28
 29     Parameters
 30     ----------
 31     data
 32         A 2D numpy array of shape (M, N).
 33     row_labels
 34         A list or array of length M with the labels for the rows.
 35     col_labels
 36         A list or array of length N with the labels for the columns.
 37     ax
 38         A `matplotlib.axes.Axes` instance to which the heatmap is plotted.  If
 39         not provided, use current axes or create a new one.  Optional.
 40     cbar_kw
 41         A dictionary with arguments to `matplotlib.Figure.colorbar`.  Optional.
 42     cbarlabel
 43         The label for the colorbar.  Optional.
 44     **kwargs
 45         All other arguments are forwarded to `imshow`.
 46     """
 47
 48     if ax is None:
 49         ax = plt.gca()
 50
 51     if cbar_kw is None:
 52         cbar_kw = {}
 53
 54     # Plot the heatmap
 55     im = ax.imshow(data, **kwargs)
 56
 57     # Create colorbar
 58     cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
 59     cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
 60
 61     # Show all ticks and label them with the respective list entries.
 62     ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
 63     ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
 64
 65     # Let the horizontal axes labeling appear on top.
 66     ax.tick_params(top=True, bottom=False,
 67                    labeltop=True, labelbottom=False)
 68
 69     # Rotate the tick labels and set their alignment.
 70     plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
 71              rotation_mode="anchor")
 72
 73     # Turn spines off and create white grid.
 74     ax.spines[:].set_visible(False)
 75
 76     ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
 77     ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
 78     ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
 79     ax.tick_params(which="minor", bottom=False, left=False)
 80
 81     return im, cbar
 82
 83
 84 def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
 85                      textcolors=("black", "white"),
 86                      threshold=None, **textkw):
 87     """
 88     A function to annotate a heatmap.
 89
 90     Parameters
 91     ----------
 92     im
 93         The AxesImage to be labeled.
 94     data
 95         Data used to annotate.  If None, the image's data is used.  Optional.
 96     valfmt
 97         The format of the annotations inside the heatmap.  This should either
 98         use the string format method, e.g. "$ {x:.2f}", or be a
 99         `matplotlib.ticker.Formatter`.  Optional.
100     textcolors
101         A pair of colors.  The first is used for values below a threshold,
102         the second for those above.  Optional.
103     threshold
104         Value in data units according to which the colors from textcolors are
105         applied.  If None (the default) uses the middle of the colormap as
106         separation.  Optional.
107     **kwargs
108         All other arguments are forwarded to each call to `text` used to create
109         the text labels.
110     """
111
112     if not isinstance(data, (list, np.ndarray)):
113         data = im.get_array()
114
115     # Normalize the threshold to the images color range.
116     if threshold is not None:
117         threshold = im.norm(threshold)
118     else:
119         threshold = im.norm(data.max())/2.
120
121     # Set default alignment to center, but allow it to be
122     # overwritten by textkw.
123     kw = dict(horizontalalignment="center",
124               verticalalignment="center")
125     kw.update(textkw)
126
127     # Get the formatter in case a string is supplied
128     if isinstance(valfmt, str):
129         valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)
130
131     # Loop over the data and create a `Text` for each "pixel".
132     # Change the text's color depending on the data.
133     texts = []
134     for i in range(data.shape[0]):
135         for j in range(data.shape[1]):
136             kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
137             text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
138             texts.append(text)
139
140     return texts
141
142
def plot_binned_statistic(r, ax, title=None, astype=None, **kwargs):
    """Plot a 2D binned statistic as an annotated heatmap.

    Parameters
    ----------
    r
        The binned statistic: an object exposing ``statistic`` (2D
        array), ``x_edge`` and ``y_edge`` (1D arrays of bin edges), such
        as the result of ``scipy.stats.binned_statistic_2d``.
    ax
        The axes to plot on.
    title
        Title set on ``ax``.  Nothing is set when None.  Optional.
    astype
        Accepted for backward compatibility; currently not used.
    **kwargs
        Forwarded to ``annotate_heatmap`` (e.g. ``valfmt``).

    Returns
    -------
    None
    """
    # Label every bin with the midpoint of its two edges.
    x_center = (r.x_edge[:-1] + r.x_edge[1:]) / 2
    y_center = (r.y_edge[:-1] + r.y_edge[1:]) / 2

    # Plot heatmap (matplotlib sample, use seaborn instead).
    # The x centers are passed as row labels and the y centers as
    # column labels: scipy lays `statistic` out as (nx, ny), so rows
    # are x bins and columns are y bins.
    im, _ = heatmap(r.statistic,
        np.around(x_center, 2), np.around(y_center, 2), ax=ax,
        cmap="coolwarm", cbarlabel="value [unit]")
    annotate_heatmap(im, **kwargs)

    # Configure
    ax.set_aspect('equal', 'box')
    if title is not None:
        ax.set_title(title)
181
def data_manual():
    """Return a small hand-written (x, y, z) sample as three numpy arrays."""
    columns = (
        [1, 1, 1, 1, 2, 2, 2, 3, 4],  # x
        [1, 1, 2, 2, 3, 4, 5, 6, 7],  # y
        [1, 9, 9, 1, 2, 2, 2, 3, 4],  # z
    )
    xs, ys, zs = (np.array(values) for values in columns)
    return xs, ys, zs
189
def data_shap(path='../../datasets/shap/shap.csv'):
    """Load the SHAP dataset from a CSV file.

    The loaded table is printed before returning.

    Parameters
    ----------
    path
        Location of the CSV file.  It must contain the columns
        ``timestep``, ``shap_values`` and ``feature_values``.  Defaults
        to the dataset shipped alongside this example.

    Returns
    -------
    tuple
        The ``timestep``, ``shap_values`` and ``feature_values`` columns,
        in that order.
    """
    data = pd.read_csv(path)
    print(data)
    return data['timestep'], data['shap_values'], data['feature_values']
195
196
197
198
# Load data
#x, y, z = data_manual()
x, y, z = data_shap()

# Bin edges for x: unit-width bins centred on every integer between
# x.min() and x.max(), i.e. [min - 0.5, min + 0.5, ..., max + 0.5].
# Starting at x.min() (instead of a fixed 0.5) keeps samples whose x
# is 0 inside the first bin; samples outside the edges are ignored by
# binned_statistic_2d.
binx = np.arange(x.min(), x.max() + 2) - 0.5

# Bin edges for y.
# Alternative for integer-valued y (e.g. data_manual):
#biny = np.arange(0, y.max()+1) + 0.5 # [0.5, 1.5, 2.5, ...., N + 0.5]
# Using np.linspace: 10 equally spaced edges over the range of y.
biny = np.linspace(y.min(), y.max(), 10)

# Manual
#binx = np.arange(5) + 0.5
#biny = np.arange(8) + 0.5

# Compute binned statistic (count) on the explicit bin edges
r1 = stats.binned_statistic_2d(x=x, y=y, values=None,
    statistic='count', bins=[binx, biny],
    expand_binnumbers=True)

# Compute binned statistic (count) on an automatic 4 x 7 grid
r2 = stats.binned_statistic_2d(x=x, y=y, values=z,
    statistic='count', bins=[4, 7],
    expand_binnumbers=False)

# Compute binned statistic (median)
r3 = stats.binned_statistic_2d(x=x, y=y, values=z,
    statistic='median', bins=[binx, biny],
    expand_binnumbers=False)

# Compute binned statistic (mean)
r4 = stats.binned_statistic_2d(x=x, y=y, values=z,
    statistic='mean', bins=[binx, biny],
    expand_binnumbers=False)


# Plot
fig, axs = plt.subplots(nrows=2, ncols=2,
    sharey=True, sharex=True, figsize=(14, 7))
plot_binned_statistic(r1, axs[0,0], title='r1 (count)', valfmt="{x:g}")
plot_binned_statistic(r2, axs[0,1], title='r2 (count)', valfmt="{x:g}")
plot_binned_statistic(r3, axs[1,0], title='r3 (median)')
plot_binned_statistic(r4, axs[1,1], title='r4 (mean)')

# Display
plt.tight_layout()
plt.show()

Total running time of the script: ( 0 minutes 1.063 seconds)

Gallery generated by Sphinx-Gallery