Note
Click here to download the full example code
05. Basic example
6 # Libraries
7 import shap
8 import numpy as np
9 import pandas as pd
10 import seaborn as sns
11
12 import matplotlib.pyplot as plt
13 import matplotlib as mpl
14 import matplotlib.colorbar
15 import matplotlib.colors
16 import matplotlib.cm
17
18 from mpl_toolkits.axes_grid1 import make_axes_locatable
19
# ``__file__`` is only defined when the code runs from a file (e.g. as a
# script); in an interactive session looking it up raises ``NameError``.
# Only that specific exception is caught so that unrelated errors surface.
try:
    __file__
    TERMINAL = True
except NameError:
    TERMINAL = False
25
26 # ------------------------
27 # Methods
28 # ------------------------
def scalar_colormap(values, cmap, vmin, vmax):
    """Map every value to an RGBA colour and wrap the colours as a palette.

    Parameters
    ----------
    values : array-like
        Values to colour; one colour is produced per value, in order.

    cmap : str
        Name of the matplotlib colormap.

    vmin, vmax : float
        Range used for normalisation. The normaliser is created with
        ``clip=True``, so values outside the range take the end colours.

    Returns
    -------
    colormap : seaborn colour palette
        Result of ``sns.color_palette`` on the per-value RGBA colours.

    norm : matplotlib.colors.Normalize
        The normaliser used to map values onto the colormap.
    """
    normaliser = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
    to_rgba = mpl.cm.ScalarMappable(norm=normaliser, cmap=cmap).to_rgba
    rgba_colours = list(map(to_rgba, values))
    return sns.color_palette(rgba_colours), normaliser
54
55
def scalar_palette(values, cmap, vmin, vmax):
    """Build a ``{value: colour}`` dictionary from a named colormap.

    Parameters
    ----------
    values : iterable
        Values to colour. Repeated values share one dictionary entry.

    cmap : str
        Colormap name, resolved through ``sns.color_palette(..., as_cmap=True)``.

    vmin, vmax : float
        Range mapped onto the colormap (``Normalize`` uses its default
        ``clip=False``).

    Returns
    -------
    colors : dict
        Mapping from each value to the colour returned by the colormap.

    norm : matplotlib.colors.Normalize
        The normaliser used to position values on the colormap.
    """
    colormap = sns.color_palette(cmap, as_cmap=True)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    colors = {value: colormap(norm(value)) for value in values}
    return colors, norm
86
87
def create_random_shap(samples, timesteps, features, seed=None):
    """Create random SHAP-like data in long (tidy) format.

    Random integers in ``[0, 100)`` are drawn for every
    (sample, timestep, feature) combination, and one row is returned per
    combination.

    .. note:: These are not real SHAP values. Either pre-process them
        (e.g. normalise) or generate them within the appropriate interval
        if realistic ranges matter.

    Parameters
    ----------
    samples : int
        The number of observations.
    timesteps : int
        The number of time steps.
    features : int
        The number of features (named ``f0``, ``f1``, ...).
    seed : int, optional
        Seed for a private ``np.random.RandomState``. When ``None`` (the
        default), the global NumPy random state is used.

    Returns
    -------
    pandas.DataFrame
        Columns ``sample``, ``timestep``, ``features``, ``shap_values`` and
        ``feature_values``, with ``samples * timesteps * features`` rows
        ordered by sample, then timestep, then feature.
    """
    # Global state by default; a dedicated generator when reproducibility is wanted.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Random "shap values" for every (sample, timestep, feature) triple.
    x = rng.randint(low=0, high=100, size=(samples, timesteps, features))

    # (sample, timestep) pairs, one per row, in the same C order as
    # ``x.reshape(-1, features)`` so that the two can be stacked side by side.
    idx = np.indices((samples, timesteps)).reshape(2, -1).T

    # Wide frame: one row per (sample, timestep), one column per feature.
    df = pd.DataFrame(
        data=np.hstack((idx, x.reshape((-1, features)))),
        columns=['sample', 'timestep'] + ['f%s' % j for j in range(features)],
    )

    # Long frame: one row per (sample, timestep, feature).
    df_stack = df.set_index(['sample', 'timestep']).stack()
    df_stack.name = 'shap_values'
    df_stack = df_stack.to_frame()
    df_stack.index.names = ['sample', 'timestep', 'features']
    df_stack = df_stack.reset_index()

    df_stack['feature_values'] = rng.randint(
        low=0, high=100, size=df_stack.shape[0])

    return df_stack
136
137
def load_shap_file(path='./data/shap.csv'):
    """Load pre-computed SHAP values from a CSV file.

    The first column of the file is dropped. Presumably it is the index
    written by ``DataFrame.to_csv``; confirm this against the code that
    produced the file.

    Parameters
    ----------
    path : str, path-like or file-like, default './data/shap.csv'
        Anything accepted by ``pd.read_csv``.

    Returns
    -------
    pandas.DataFrame
        The file contents without the first column. The plotting code in
        this example expects the columns ``sample``, ``timestep``,
        ``features``, ``shap_values`` and ``feature_values``.
    """
    data = pd.read_csv(path)
    # Drop the leading column (see docstring).
    return data.iloc[:, 1:]
Let’s generate and/or load the SHAP values.
# .. note: The right format to use for plotting depends on the library
#          we use. The long ("tidy") structure loaded here is the one
#          that works well with seaborn.
data = create_random_shap(10, 6, 4)
# data = load_shap_file()
# data = data[data['sample'] < 100]

# Wide views of the same data: one row per (sample, timestep) pair and
# one column per feature.
shap_values, feature_values = (
    pd.pivot_table(data,
                   values=column,
                   index=['sample', 'timestep'],
                   columns=['features'])
    for column in ('shap_values', 'feature_values')
)

# Show
if TERMINAL:
    print("\nShow:")
    print(data, shap_values, feature_values, sep='\n')
Let’s see what the data looks like.
# First ten rows of ``data``.
data.iloc[:10]
Let’s see what shap_values looks like.
# First ten rows and first five feature columns.
shap_values.head(10).iloc[:, :5]
Let’s see what feature_values looks like.
# First ten rows and first five feature columns.
feature_values.head(10).iloc[:, :5]
Display using shap.summary_plot
The first option is to use the shap
library to plot the results.
# Let's define/extract some useful variables.
N = 4  # max loops filter: upper bound used to limit the plotting loops
TIMESTEPS = shap_values.index.unique(level='timestep').size  # number of timesteps
SAMPLES = shap_values.index.unique(level='sample').size  # number of samples

# Shared x-range so the summary plots below use the same scale.
shap_min, shap_max = data['shap_values'].min(), data['shap_values'].max()
Now, let’s display the shap values for all features in each timestep.
# For each timestep (visualise all features).
# Iterate over the timestep values actually present in the index instead of
# ``range(TIMESTEPS)``: the loop variable is compared against the index
# level, so data whose timesteps do not start at 0 must still match.
timestep_level = shap_values.index.get_level_values('timestep')
for step in shap_values.index.unique(level='timestep')[:N]:
    # Rows of this timestep, for every sample.
    mask = timestep_level == step

    # Create auxiliary matrices
    shap_aux = shap_values.iloc[mask]
    feat_aux = feature_values.iloc[mask]

    # Display
    plt.figure()
    plt.title("Timestep: %s" % step)
    shap.summary_plot(shap_aux.to_numpy(), feat_aux, show=False)
    plt.xlim(shap_min, shap_max)
Now, let’s display the shap values for all timesteps of each feature.
# For each feature (visualise all time-steps)
for i, f in enumerate(shap_values.columns[:N]):
    # Reshape the selected feature into a (sample x timestep) table.
    # ``unstack`` aligns values by their index labels, so it does not rely on
    # the rows being sorted with every sample having every timestep, and the
    # column labels reflect the real timestep values.
    shap_aux = shap_values.iloc[:, i].unstack('timestep')
    feat_aux = feature_values.iloc[:, i].unstack('timestep')
    feat_aux.columns = ['timestep %s' % t for t in feat_aux.columns]

    # Show (``sort=False`` keeps the timesteps in their natural order)
    plt.figure()
    plt.title("Feature: %s" % f)
    shap.summary_plot(shap_aux.to_numpy(), feat_aux, sort=False, show=False)
    plt.xlim(shap_min, shap_max)
Note
If the y-axis represents timesteps, the sort parameter of the summary_plot function is set to False, so that the timesteps keep their natural order instead of being ranked by importance.
Display using sns.stripplot
Warning
This method seems to be quite slow.
Let’s display the shap values for each feature and all time steps. In contrast to the previous example, the timesteps are now displayed on the x-axis and the y-axis contains the shap values.
def add_colorbar(fig, cmap, norm, ax=None):
    """Attach a vertical colorbar to the right of an axes.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives the new colorbar axes.
    cmap : matplotlib.colors.Colormap
        Colormap displayed by the colorbar.
    norm : matplotlib.colors.Normalize
        Normaliser defining the value range shown on the colorbar.
    ax : matplotlib.axes.Axes, optional
        Axes next to which the colorbar is placed. Defaults to
        ``fig.gca()``, so the colorbar always belongs to ``fig``.

    Returns
    -------
    matplotlib.colorbar.ColorbarBase
        The created colorbar.
    """
    if ax is None:
        ax = fig.gca()
    divider = make_axes_locatable(ax)
    # Thin (5% wide) axes appended horizontally next to ``ax``.
    ax_cb = divider.new_horizontal(size="5%", pad=0.05)
    fig.add_axes(ax_cb)
    return matplotlib.colorbar.ColorbarBase(
        ax_cb, cmap=cmap, norm=norm, orientation='vertical')
269
270
# Loop: one figure per feature, at most N figures.
for i, (name, df) in enumerate(data.groupby('features')):

    # Stop after N features, consistently with the ``[:N]`` loops above.
    if i >= N:
        break

    # Per-value colours for the points, plus the matching normaliser.
    values = df.feature_values
    palette, norm = scalar_palette(values=values, cmap='coolwarm',
                                   vmin=values.min(), vmax=values.max())

    # Debug output only when running as a script.
    if TERMINAL:
        print(df)

    # Display
    fig, ax = plt.subplots()
    ax = sns.stripplot(x='timestep',
                       y='shap_values',
                       hue='feature_values',
                       palette=palette,
                       data=df,
                       ax=ax)

    # ``plt.get_cmap`` works on old and new matplotlib releases, whereas
    # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9.
    cmap = plt.get_cmap('coolwarm')

    # Configure axes (hide the per-value legend; the colorbar replaces it).
    ax.set_title(name)
    ax.legend([], [], frameon=False)
    ax.invert_xaxis()
    add_colorbar(fig, cmap, norm)

# Show
plt.show()
Out:
sample timestep features shap_values feature_values
0 0 0 f0 51 25
4 0 1 f0 28 91
8 0 2 f0 26 54
12 0 3 f0 80 93
16 0 4 f0 25 34
20 0 5 f0 58 89
24 1 0 f0 68 52
28 1 1 f0 94 24
32 1 2 f0 12 11
36 1 3 f0 84 36
40 1 4 f0 23 65
44 1 5 f0 65 40
48 2 0 f0 40 23
52 2 1 f0 44 7
56 2 2 f0 55 61
60 2 3 f0 91 96
64 2 4 f0 6 77
68 2 5 f0 90 18
72 3 0 f0 32 55
76 3 1 f0 86 78
80 3 2 f0 49 75
84 3 3 f0 91 48
88 3 4 f0 19 69
92 3 5 f0 82 1
96 4 0 f0 56 13
100 4 1 f0 50 93
104 4 2 f0 76 62
108 4 3 f0 56 8
112 4 4 f0 86 86
116 4 5 f0 86 5
120 5 0 f0 26 59
124 5 1 f0 97 49
128 5 2 f0 13 99
132 5 3 f0 9 14
136 5 4 f0 10 60
140 5 5 f0 96 85
144 6 0 f0 31 81
148 6 1 f0 38 89
152 6 2 f0 29 63
156 6 3 f0 87 32
160 6 4 f0 68 51
164 6 5 f0 56 68
168 7 0 f0 86 66
172 7 1 f0 67 20
176 7 2 f0 20 33
180 7 3 f0 53 61
184 7 4 f0 82 55
188 7 5 f0 16 97
192 8 0 f0 63 38
196 8 1 f0 63 47
200 8 2 f0 94 99
204 8 3 f0 89 50
208 8 4 f0 62 79
212 8 5 f0 1 41
216 9 0 f0 1 16
220 9 1 f0 87 71
224 9 2 f0 83 31
228 9 3 f0 45 86
232 9 4 f0 24 38
236 9 5 f0 57 7
sample timestep features shap_values feature_values
1 0 0 f1 39 44
5 0 1 f1 30 50
9 0 2 f1 60 73
13 0 3 f1 42 17
17 0 4 f1 40 60
21 0 5 f1 19 12
25 1 0 f1 92 9
29 1 1 f1 60 47
33 1 2 f1 13 53
37 1 3 f1 15 60
41 1 4 f1 34 48
45 1 5 f1 93 79
49 2 0 f1 38 6
53 2 1 f1 37 29
57 2 2 f1 58 27
61 2 3 f1 96 64
65 2 4 f1 77 95
69 2 5 f1 37 11
73 3 0 f1 39 36
77 3 1 f1 54 28
81 3 2 f1 86 41
85 3 3 f1 62 18
89 3 4 f1 43 33
93 3 5 f1 36 92
97 4 0 f1 44 63
101 4 1 f1 70 82
105 4 2 f1 78 48
109 4 3 f1 33 95
113 4 4 f1 71 49
117 4 5 f1 41 5
121 5 0 f1 81 46
125 5 1 f1 43 40
129 5 2 f1 50 45
133 5 3 f1 79 21
137 5 4 f1 49 56
141 5 5 f1 45 73
145 6 0 f1 44 96
149 6 1 f1 7 19
153 6 2 f1 44 49
157 6 3 f1 14 51
161 6 4 f1 82 25
165 6 5 f1 28 48
169 7 0 f1 89 51
173 7 1 f1 33 59
177 7 2 f1 24 2
181 7 3 f1 79 39
185 7 4 f1 65 42
189 7 5 f1 60 78
193 8 0 f1 81 44
197 8 1 f1 90 69
201 8 2 f1 88 79
205 8 3 f1 65 41
209 8 4 f1 96 71
213 8 5 f1 24 29
217 9 0 f1 78 12
221 9 1 f1 54 4
225 9 2 f1 81 7
229 9 3 f1 13 53
233 9 4 f1 81 87
237 9 5 f1 86 29
sample timestep features shap_values feature_values
2 0 0 f2 73 58
6 0 1 f2 32 29
10 0 2 f2 6 56
14 0 3 f2 59 45
18 0 4 f2 64 40
22 0 5 f2 84 55
26 1 0 f2 75 80
30 1 1 f2 29 5
34 1 2 f2 58 29
38 1 3 f2 80 57
42 1 4 f2 77 83
46 1 5 f2 60 65
50 2 0 f2 89 24
54 2 1 f2 5 91
58 2 2 f2 58 62
62 2 3 f2 16 43
66 2 4 f2 59 35
70 2 5 f2 59 19
74 3 0 f2 55 0
78 3 1 f2 55 49
82 3 2 f2 33 63
86 3 3 f2 94 52
90 3 4 f2 24 78
94 3 5 f2 93 29
98 4 0 f2 46 24
102 4 1 f2 98 50
106 4 2 f2 12 60
110 4 3 f2 44 35
114 4 4 f2 34 74
118 4 5 f2 40 37
122 5 0 f2 88 75
126 5 1 f2 65 3
130 5 2 f2 68 96
134 5 3 f2 14 59
138 5 4 f2 82 71
142 5 5 f2 75 83
146 6 0 f2 27 96
150 6 1 f2 53 51
154 6 2 f2 98 72
158 6 3 f2 34 17
162 6 4 f2 14 62
166 6 5 f2 96 41
170 7 0 f2 22 10
174 7 1 f2 64 8
178 7 2 f2 59 88
182 7 3 f2 37 95
186 7 4 f2 51 82
190 7 5 f2 15 19
194 8 0 f2 29 5
198 8 1 f2 41 68
202 8 2 f2 52 48
206 8 3 f2 1 31
210 8 4 f2 38 84
214 8 5 f2 54 69
218 9 0 f2 89 93
222 9 1 f2 94 78
226 9 2 f2 18 58
230 9 3 f2 48 74
234 9 4 f2 24 97
238 9 5 f2 91 89
sample timestep features shap_values feature_values
3 0 0 f3 41 52
7 0 1 f3 91 32
11 0 2 f3 25 30
15 0 3 f3 29 44
19 0 4 f3 6 58
23 0 5 f3 91 82
27 1 0 f3 12 58
31 1 1 f3 92 28
35 1 2 f3 18 25
39 1 3 f3 55 2
43 1 4 f3 23 71
47 1 5 f3 20 90
51 2 0 f3 45 95
55 2 1 f3 25 35
59 2 2 f3 30 75
63 2 3 f3 42 37
67 2 4 f3 24 89
71 2 5 f3 23 72
75 3 0 f3 93 49
79 3 1 f3 54 10
83 3 2 f3 58 36
87 3 3 f3 39 10
91 3 4 f3 10 25
95 3 5 f3 33 67
99 4 0 f3 85 64
103 4 1 f3 7 40
107 4 2 f3 49 27
111 4 3 f3 60 37
115 4 4 f3 86 68
119 4 5 f3 71 41
123 5 0 f3 15 60
127 5 1 f3 44 0
131 5 2 f3 53 53
135 5 3 f3 56 69
139 5 4 f3 45 94
143 5 5 f3 19 15
147 6 0 f3 70 41
151 6 1 f3 73 42
155 6 2 f3 91 87
159 6 3 f3 20 92
163 6 4 f3 40 54
167 6 5 f3 90 11
171 7 0 f3 92 77
175 7 1 f3 84 19
179 7 2 f3 76 14
183 7 3 f3 69 69
187 7 4 f3 97 99
191 7 5 f3 13 56
195 8 0 f3 12 41
199 8 1 f3 51 56
203 8 2 f3 85 14
207 8 3 f3 59 93
211 8 4 f3 74 62
215 8 5 f3 32 0
219 9 0 f3 32 81
223 9 1 f3 43 56
227 9 2 f3 57 31
231 9 3 f3 6 46
235 9 4 f3 87 31
239 9 5 f3 40 98
Display using sns.swarmplot
Let’s display the shap values of each feature across all timesteps.
# Loop: one figure per feature, at most N figures.
for i, (name, df) in enumerate(data.groupby('features')):

    # Stop after N features, consistently with the ``[:N]`` loops above.
    if i >= N:
        break

    # Per-value colours for the points, plus the matching normaliser.
    values = df.feature_values
    palette, norm = scalar_palette(values=values, cmap='coolwarm',
                                   vmin=values.min(), vmax=values.max())

    # Display
    fig, ax = plt.subplots()
    ax = sns.swarmplot(x='timestep',
                       y='shap_values',
                       hue='feature_values',
                       palette=palette,
                       data=df,
                       size=2,
                       ax=ax)

    # ``plt.get_cmap`` works on old and new matplotlib releases, whereas
    # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9.
    cmap = plt.get_cmap('coolwarm')

    # Configure axes (hide the per-value legend; the colorbar replaces it).
    ax.set_title(name)
    ax.legend([], [], frameon=False)
    ax.invert_xaxis()
    add_colorbar(fig, cmap, norm)

# Show
plt.show()
# NOTE(review): unfinished FacetGrid draft, kept for reference only. It is
# written as comments rather than a bare string literal so that it is not
# echoed as cell output. It cannot run as is: ``df_stack`` and the "f", "t"
# and "value" columns are not defined at this point of the script.
#
# sns.set_theme(style="ticks")
#
# # Create a dataset with many short random walks
# rs = np.random.RandomState(4)
# pos = rs.randint(-1, 2, (20, 5)).cumsum(axis=1)
# pos -= pos[:, 0, np.newaxis]
# step = np.tile(range(5), 20)
# walk = np.repeat(range(20), 5)
# df = pd.DataFrame(np.c_[pos.flat, step, walk],
#                   columns=["position", "step", "walk"])
# # Initialize a grid of plots with an Axes for each walk
# #grid = sns.FacetGrid(df_stack, col="walk", hue="f", palette="tab20c",
# #                     col_wrap=4, height=1.5)
#
# grid = sns.FacetGrid(df_stack, hue="f",
#                      palette="tab20c", height=1.5)
#
# # Draw a horizontal line to show the starting point
# grid.refline(y=0, linestyle=":")
#
# # Draw a line plot to show the trajectory of each random walk
# grid.map(plt.plot, "t", "value", marker="o")
#
# # Adjust the tick positions and labels
# grid.set(xticks=np.arange(5), yticks=[-3, 3],
#          xlim=(-.5, 4.5), ylim=(-3.5, 3.5))
#
# # Adjust the arrangement of the plots
# grid.fig.tight_layout(w_pad=1)

#plt.show()
Out:
'\nsns.set_theme(style="ticks")\n\n# Create a dataset with many short random walks\nrs = np.random.RandomState(4)\npos = rs.randint(-1, 2, (20, 5)).cumsum(axis=1)\npos -= pos[:, 0, np.newaxis]\nstep = np.tile(range(5), 20)\nwalk = np.repeat(range(20), 5)\ndf = pd.DataFrame(np.c_[pos.flat, step, walk],\n columns=["position", "step", "walk"])\n# Initialize a grid of plots with an Axes for each walk\n#grid = sns.FacetGrid(df_stack, col="walk", hue="f", palette="tab20c",\n# col_wrap=4, height=1.5)\n\ngrid = sns.FacetGrid(df_stack, hue="f",\n palette="tab20c", height=1.5)\n\n# Draw a horizontal line to show the starting point\ngrid.refline(y=0, linestyle=":")\n\n# Draw a line plot to show the trajectory of each random walk\ngrid.map(plt.plot, "t", "value", marker="o")\n\n# Adjust the tick positions and labels\ngrid.set(xticks=np.arange(5), yticks=[-3, 3],\n xlim=(-.5, 4.5), ylim=(-3.5, 3.5))\n\n# Adjust the arrangement of the plots\ngrid.fig.tight_layout(w_pad=1)\n\n'
Display using sns.FacetGrid
399 #g = sns.FacetGrid(df_stack, col="f", hue='original')
400 #g.map(sns.swarmplot, "t", "value", alpha=.7)
401 #g.add_legend()
Display using shap.beeswarm
409 # REF: https://github.com/slundberg/shap/blob/master/shap/plots/_beeswarm.py
410 #
411 # .. note: It needs a kernel explainer, and while it works with
412 # common kernels (plot_main07.py) it does not work with
413 # the DeepKernel for some reason (mask related).
Total running time of the script: ( 0 minutes 4.630 seconds)