Note
Click here to download the full example code
07.a stats.2dbin and mpl.heatmap
Use scipy.stats.binned_statistic_2d to bin the data and display each binned statistic as an annotated matplotlib heatmap.
Out:
Unnamed: 0 sample timestep features feature_values shap_values
0 0 0 0 Ward Lactate 0.000000 0.000652
1 1 0 0 Ward Glucose 0.000000 -0.000596
2 2 0 0 Ward sO2 0.000000 0.000231
3 3 0 0 White blood cell count, blood 0.000000 0.000582
4 4 0 0 Platelets 0.000000 -0.001705
... ... ... ... ... ... ...
251995 251995 999 6 Procalcitonin 0.000000 0.000027
251996 251996 999 6 Ferritin 0.000000 -0.001375
251997 251997 999 6 D-Dimer 0.000000 0.000045
251998 251998 999 6 sex -1.000000 -0.002359
251999 251999 999 6 age 0.169952 0.000237
[252000 rows x 6 columns]
<__array_function__ internals>:180: UserWarning:
Warning: converting a masked element to nan.
/Users/cbit/Desktop/repositories/environments/venv-py3109-python-spare-code/lib/python3.10/site-packages/matplotlib/colors.py:1311: UserWarning:
Warning: converting a masked element to nan.
/Users/cbit/Desktop/repositories/environments/venv-py3109-python-spare-code/lib/python3.10/site-packages/matplotlib/ticker.py:374: FutureWarning:
Format strings passed to MaskedConstant are ignored, but in future may error or produce different behavior
9 import matplotlib
10 import numpy as np
11 import pandas as pd
12 import matplotlib as mpl
13 import matplotlib.pyplot as plt
14
15 from scipy import stats
16
# Use 8pt text everywhere: base font, axes titles, axis labels and tick labels.
# See https://matplotlib.org/devdocs/users/explain/customizing.html
mpl.rcParams.update({
    'font.size': 8,
    'axes.titlesize': 8,
    'axes.labelsize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
})
23
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    """
    Draw a 2D array as an image with labelled ticks and an attached colorbar.

    Parameters
    ----------
    data
        A 2D numpy array of shape (M, N); drawn with ``imshow``.
    row_labels
        M labels, one per row, placed on the y-axis.
    col_labels
        N labels, one per column, placed on the x-axis (shown above the image).
    ax
        The `matplotlib.axes.Axes` to draw into. When omitted, the current
        axes (``plt.gca()``) are used. Optional.
    cbar_kw
        Extra keyword arguments for `matplotlib.Figure.colorbar`. Optional.
    cbarlabel
        Label of the colorbar. Optional.
    **kwargs
        Passed through unchanged to `imshow`.

    Returns
    -------
    im
        The `AxesImage` returned by `imshow`.
    cbar
        The colorbar created for ``im``.
    """
    target = plt.gca() if ax is None else ax
    colorbar_options = cbar_kw if cbar_kw is not None else {}

    # The image itself plus its colorbar; the colorbar label is rotated -90°.
    image = target.imshow(data, **kwargs)
    colorbar = target.figure.colorbar(image, ax=target, **colorbar_options)
    colorbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    n_rows, n_cols = data.shape[0], data.shape[1]

    # One major tick per cell, labelled with the caller's labels.
    target.set_xticks(np.arange(n_cols), labels=col_labels)
    target.set_yticks(np.arange(n_rows), labels=row_labels)

    # Put the column labels above the image instead of below it.
    target.tick_params(top=True, bottom=False,
                       labeltop=True, labelbottom=False)

    # Slant the column labels so that long labels do not overlap.
    plt.setp(target.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # No frame; instead draw white separators between cells, using minor
    # ticks placed on the cell borders (-0.5, 0.5, ..., n - 0.5).
    target.spines[:].set_visible(False)
    target.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    target.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    target.grid(which="minor", color="w", linestyle='-', linewidth=3)
    target.tick_params(which="minor", bottom=False, left=False)

    return image, colorbar
82
83
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    Write the value of every valid cell of a heatmap as text on top of it.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Data used to annotate, as a 2D list or numpy array. If it is not a
        list/array (e.g. None), the image's data (``im.get_array()``) is
        used. Optional.
    valfmt
        The format of the annotations inside the heatmap. This should either
        use the string format method, e.g. "$ {x:.2f}", or be a
        `matplotlib.ticker.Formatter` (any callable ``f(value, pos)``).
        Optional.
    textcolors
        A pair of colors. The first is used for values below (or equal to)
        the threshold, the second for those above. Optional.
    threshold
        Value in data units according to which the colors from textcolors are
        applied. If None (the default), half of the normalized maximum of the
        valid data is used as separation. Optional.
    **textkw
        All other arguments are forwarded to each call to `text` used to create
        the text labels (they may override the default centered alignment).

    Returns
    -------
    list
        The created `Text` objects in row-major order. Cells whose value is
        masked, NaN or infinite (e.g. empty bins of a 'median' or 'mean'
        binned statistic) are skipped, so no text is drawn for them.
    """
    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()

    # Accept nested lists as well as (masked) arrays, and mask NaN/inf on top
    # of any existing mask so invalid cells are never formatted or colored.
    data = np.ma.masked_invalid(np.ma.asarray(data))

    if data.count() == 0:
        # Nothing to annotate; also avoids normalizing a fully masked max().
        return []

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max()) / 2.

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied. Import the submodule
    # explicitly instead of relying on pyplot having loaded it already.
    if isinstance(valfmt, str):
        from matplotlib import ticker
        valfmt = ticker.StrMethodFormatter(valfmt)

    # Loop over the valid cells and create a `Text` for each of them.
    # Change the text's color depending on the data.
    invalid = np.ma.getmaskarray(data)
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            if invalid[i, j]:
                continue
            value = data[i, j]
            kw.update(color=textcolors[int(im.norm(value) > threshold)])
            # text() takes (x, y) = (column, row).
            texts.append(im.axes.text(j, i, valfmt(value, None), **kw))

    return texts
141
142
def plot_binned_statistic(r, ax, title=None, astype=None,
                          cmap="coolwarm", cbarlabel="value [unit]",
                          **kwargs):
    """Plot a 2D binned statistic as an annotated heatmap.

    Parameters
    ----------
    r: the result of ``scipy.stats.binned_statistic_2d``; its ``statistic``,
        ``x_edge`` and ``y_edge`` attributes are used.
    ax: the axes to plot on.
    title: optional axes title.
    astype: unused; kept only for backward compatibility.
    cmap: colormap passed to ``imshow`` (default ``"coolwarm"``).
    cbarlabel: label of the colorbar (default ``"value [unit]"``).
    **kwargs: forwarded to ``annotate_heatmap`` (e.g. ``valfmt``).

    Returns
    -------
    tuple
        ``(im, cbar, texts)``: the image, its colorbar and the list of
        annotation texts.
    """
    # Bin centres, used as tick labels.
    x_center = (r.x_edge[:-1] + r.x_edge[1:]) / 2
    y_center = (r.y_edge[:-1] + r.y_edge[1:]) / 2

    # r.statistic has shape (nx, ny): heatmap rows are x bins and columns
    # are y bins, so the x centres label the rows and the y centres the columns.
    im, cbar = heatmap(r.statistic,
                       np.around(x_center, 2), np.around(y_center, 2), ax=ax,
                       cmap=cmap, cbarlabel=cbarlabel)
    texts = annotate_heatmap(im, **kwargs)

    # Equal data-unit scaling on both axes (square cells); resize the box.
    ax.set_aspect('equal', 'box')
    if title is not None:
        ax.set_title(title)

    return im, cbar, texts
181
def data_manual():
    """Return a small hand-written (x, y, z) sample.

    Returns
    -------
    tuple of numpy.ndarray
        Three integer arrays of length 9: the x and y coordinates of each
        point and the value z associated with it.
    """
    samples = [
        # (x, y, z)
        (1, 1, 1),
        (1, 1, 9),
        (1, 2, 9),
        (1, 2, 1),
        (2, 3, 2),
        (2, 4, 2),
        (2, 5, 2),
        (3, 6, 3),
        (4, 7, 4),
    ]
    xs, ys, zs = (np.array(column) for column in zip(*samples))
    return xs, ys, zs
189
def data_shap(path='../../datasets/shap/shap.csv'):
    """Load the SHAP example dataset and print it.

    Parameters
    ----------
    path: str or path-like
        CSV file to read. Defaults to the original location
        ``'../../datasets/shap/shap.csv'``, which is resolved relative to
        the current working directory.

    Returns
    -------
    tuple of pandas.Series
        The ``timestep``, ``shap_values`` and ``feature_values`` columns,
        in that order.
    """
    data = pd.read_csv(path)
    print(data)
    return data.timestep, data.shap_values, data.feature_values
195
196
197
198
# Load data (swap the comments to use the small hand-made sample instead).
# x, y, z = data_manual()
x, y, z = data_shap()

# x bin edges centred on the integers: [min - 0.5, min + 0.5, ..., max + 0.5].
# Starting from x.min() keeps the smallest value inside the first bin; edges
# starting at 0.5 would silently drop timestep 0 of the SHAP data. For
# data_manual() (x in 1..4) this gives [0.5, 1.5, 2.5, 3.5, 4.5].
binx = np.arange(x.min(), x.max() + 2) - 0.5

# y bin edges: 10 evenly spaced edges (9 bins) spanning the observed range.
biny = np.linspace(y.min(), y.max(), 10)

# Manual alternative (suited to data_manual()):
# binx = np.arange(5) + 0.5
# biny = np.arange(8) + 0.5

# Compute binned statistic (count)
r1 = stats.binned_statistic_2d(x=x, y=y, values=None,
                               statistic='count', bins=[binx, biny],
                               expand_binnumbers=True)

# Compute binned statistic (count, using 4 x 7 equal-width bins)
r2 = stats.binned_statistic_2d(x=x, y=y, values=z,
                               statistic='count', bins=[4, 7],
                               expand_binnumbers=False)

# Compute binned statistic (median)
r3 = stats.binned_statistic_2d(x=x, y=y, values=z,
                               statistic='median', bins=[binx, biny],
                               expand_binnumbers=False)

# Compute binned statistic (mean)
r4 = stats.binned_statistic_2d(x=x, y=y, values=z,
                               statistic='mean', bins=[binx, biny],
                               expand_binnumbers=False)

# Plot. Axes are deliberately not shared: r2 uses different bins from the
# other panels, and shared axes would also share limits and tick labels.
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(14, 7))
plot_binned_statistic(r1, axs[0, 0], title='r1 (count)', valfmt="{x:g}")
plot_binned_statistic(r2, axs[0, 1], title='r2 (count)', valfmt="{x:g}")
plot_binned_statistic(r3, axs[1, 0], title='r3 (median)')
plot_binned_statistic(r4, axs[1, 1], title='r4 (mean)')

# Display
plt.tight_layout()
plt.show()
Total running time of the script: ( 0 minutes 1.063 seconds)