05. Basic example

  6 # Libraries
  7 import shap
  8 import numpy as np
  9 import pandas as pd
 10 import seaborn as sns
 11
 12 import matplotlib.pyplot as plt
 13 import matplotlib as mpl
 14 import matplotlib.colorbar
 15 import matplotlib.colors
 16 import matplotlib.cm
 17
 18 from mpl_toolkits.axes_grid1 import make_axes_locatable
 19
 20 try:
 21     __file__
 22     TERMINAL = True
 23 except:
 24     TERMINAL = False
 25
 26 # ------------------------
 27 # Methods
 28 # ------------------------
 29 def scalar_colormap(values, cmap, vmin, vmax):
 30     """This method creates a colormap based on values.
 31
 32     Parameters
 33     ----------
 34     values : array-like
 35     The values to create the corresponding colors
 36
 37     cmap : str
 38     The colormap
 39
 40     vmin, vmax : float
 41     The minimum and maximum possible values
 42
 43     Returns
 44     -------
 45     scalar colormap
 46     """
 47     # Create scalar mappable
 48     norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
 49     mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
 50     # Get color map
 51     colormap = sns.color_palette([mapper.to_rgba(i) for i in values])
 52     # Return
 53     return colormap, norm
 54
 55
 56 def scalar_palette(values, cmap, vmin, vmax):
 57     """This method creates a colorpalette based on values.
 58
 59     Parameters
 60     ----------
 61     values : array-like
 62     The values to create the corresponding colors
 63
 64     cmap : str
 65     The colormap
 66
 67     vmin, vmax : float
 68     The minimum and maximum possible values
 69
 70     Returns
 71     -------
 72     scalar colormap
 73
 74     """
 75     # Create a matplotlib colormap from name
 76     #cmap = sns.light_palette(cmap, reverse=False, as_cmap=True)
 77     cmap = sns.color_palette(cmap, as_cmap=True)
 78     # Normalize to the range of possible values from df["c"]
 79     norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
 80     # Create a color dictionary (value in c : color from colormap)
 81     colors = {}
 82     for cval in values:
 83         colors.update({cval : cmap(norm(cval))})
 84     # Return
 85     return colors, norm
 86
 87
 88 def create_random_shap(samples, timesteps, features):
 89     """Create random LSTM data.
 90
 91     .. note: No need to create the 3D matrix and then reshape to
 92              2D. It would be possible to create directly the 2D
 93              matrix.
 94
 95     Parameters
 96     ----------
 97     samples: int
 98         The number of observations
 99     timesteps: int
100         The number of time steps
101     features: int
102         The number of features
103
104     Returns
105     -------
106     Stacked matrix with the data.
107
108     """
109     # .. note: Either perform a pre-processing step such as
110     #          normalization or generate the features within
111     #          the appropriate interval.
112     # Create dataset
113     x = np.random.randint(low=0, high=100,
114         size=(samples, timesteps, features))
115     y = np.random.randint(low=0, high=2, size=samples).astype(float)
116     i = np.vstack(np.dstack(np.indices((samples, timesteps))))
117
118     # Create DataFrame
119     df = pd.DataFrame(
120         data=np.hstack((i, x.reshape((-1, features)))),
121         columns=['sample', 'timestep'] + \
122                 ['f%s'%j for j in range(features)]
123     )
124
125     df_stack = df.set_index(['sample', 'timestep']).stack()
126     df_stack = df_stack
127     df_stack.name = 'shap_values'
128     df_stack = df_stack.to_frame()
129     df_stack.index.names = ['sample', 'timestep', 'features']
130     df_stack = df_stack.reset_index()
131
132     df_stack['feature_values'] = np.random.randint(
133         low=0, high=100, size=df_stack.shape[0])
134
135     return df_stack
136
137
def load_shap_file():
    """Load the shap values stored in './data/shap.csv'.

    Returns
    -------
    pd.DataFrame
        The content of the file without its first column
        (presumably the index saved along with the data; verify
        against the file).
    """
    frame = pd.read_csv('./data/shap.csv')
    # Keep every row, drop the first column
    return frame.iloc[:, 1:]

Let’s generate (or load) the shap values.

# .. note: The right format to use for plotting depends
#          on the library we use. The stacked (long) data
#          structure is convenient when using seaborn.
# Load data
data = create_random_shap(10, 6, 4)
#data = load_shap_file()
#data = data[data['sample'] < 100]

# Wide tables (rows: sample and timestep, columns: features),
# one for the shap values and one for the feature values.
shap_values, feature_values = (
    pd.pivot_table(data,
        values=column,
        index=['sample', 'timestep'],
        columns=['features'])
    for column in ('shap_values', 'feature_values')
)

# Show
if TERMINAL:
    print("\nShow:")
    print(data, shap_values, feature_values, sep="\n")

Let’s see what the data looks like

174 data.head(10)
sample timestep features shap_values feature_values
0 0 0 f0 51 25
1 0 0 f1 39 44
2 0 0 f2 73 58
3 0 0 f3 41 52
4 0 1 f0 28 91
5 0 1 f1 30 50
6 0 1 f2 32 29
7 0 1 f3 91 32
8 0 2 f0 26 54
9 0 2 f1 60 73


Let’s see what shap_values looks like

178 shap_values.iloc[:10, :5]
features f0 f1 f2 f3
sample timestep
0 0 51 39 73 41
1 28 30 32 91
2 26 60 6 25
3 80 42 59 29
4 25 40 64 6
5 58 19 84 91
1 0 68 92 75 12
1 94 60 29 92
2 12 13 58 18
3 84 15 80 55


Let’s see what feature_values looks like

182 feature_values.iloc[:10, :5]
features f0 f1 f2 f3
sample timestep
0 0 25 44 58 52
1 91 50 29 32
2 54 73 56 30
3 93 17 45 44
4 34 60 40 58
5 89 12 55 82
1 0 52 9 80 58
1 24 47 5 28
2 11 53 29 25
3 36 60 57 2


Display using shap.summary_plot

The first option is to use the shap library to plot the results.

# Let's define/extract some useful variables.
# Maximum number of plots created by each of the loops below
N = 4
# Number of distinct timesteps and samples in the index
TIMESTEPS = shap_values.index.unique(level='timestep').size
SAMPLES = shap_values.index.unique(level='sample').size

# Range of the shap values (used as common x-axis limits)
shap_min = data['shap_values'].min()
shap_max = data['shap_values'].max()

Now, let’s display the shap values for all features in each timestep.

# For each timestep (visualise all features)
for step in shap_values.index.unique(level='timestep')[:N]:

    # .. note: Computing positions by hand, that is
    #          np.arange(SAMPLES)*TIMESTEPS + step, is only necessary
    #          when working with a numpy array. Since the DataFrame
    #          carries the timestep in its index, we select by that
    #          index level instead.
    # Boolean mask with the rows of every sample at this timestep. The
    # comparison uses the timestep label itself (not a loop counter) so
    # that it remains valid when the labels do not start at zero.
    indice = shap_values.index.get_level_values('timestep') == step

    # Create auxiliary matrices (both tables are masked positionally)
    shap_aux = shap_values.iloc[indice]
    feat_aux = feature_values.iloc[indice]

    # Display
    plt.figure()
    plt.title("Timestep: %s" % step)
    shap.summary_plot(shap_aux.to_numpy(), feat_aux, show=False)
    # Same x-axis limits in all figures
    plt.xlim(shap_min, shap_max)
  • Timestep: 0
  • Timestep: 1
  • Timestep: 2
  • Timestep: 3

Now, let’s display the shap values for all timesteps of each feature.

# For each feature (visualise all time-steps)
for i, f in enumerate(shap_values.columns[:N]):

    # Select the feature column and fold it into a matrix with one
    # column per timestep (assumes the rows are ordered by sample and
    # then by timestep, with every timestep present for each sample).
    shap_aux = shap_values.iloc[:, i].to_numpy().reshape(-1, TIMESTEPS)
    feat_aux = pd.DataFrame(
        feature_values.iloc[:, i].to_numpy().reshape(-1, TIMESTEPS),
        columns=['timestep %s' % j for j in range(TIMESTEPS)])

    # Show (sort=False is passed so the rows are not reordered)
    plt.figure()
    plt.title("Feature: %s" % f)
    shap.summary_plot(shap_aux, feat_aux, sort=False, show=False)
    # Same x-axis limits in all figures
    plt.xlim(shap_min, shap_max)
  • Feature: f0
  • Feature: f1
  • Feature: f2
  • Feature: f3

Note

When the y-axis represents timesteps, the sort parameter of the summary_plot function is set to False so that the timesteps are not reordered.

Display using sns.stripplot

Warning

This method seems to be quite slow.

Let’s display the shap values for each feature and all time steps. In contrast to the previous example, the timesteps are now displayed on the x-axis and the y-axis contains the shap values.

def add_colorbar(fig, cmap, norm):
    """Add a vertical colorbar next to the current axes.

    Parameters
    ----------
    fig : matplotlib figure
        The figure to which the colorbar axes is added.

    cmap : matplotlib colormap
        The colormap displayed in the colorbar.

    norm : matplotlib normalization
        The normalization that maps data values to the colormap.

    Returns
    -------
    The colorbar that has been created.
    """
    # Reserve a strip (5% of the width) beside the current axes
    divider = make_axes_locatable(plt.gca())
    ax_cb = divider.new_horizontal(size="5%", pad=0.05)
    fig.add_axes(ax_cb)
    # Draw the colorbar in the reserved strip
    return matplotlib.colorbar.ColorbarBase(ax_cb,
        cmap=cmap, norm=norm, orientation='vertical')
269
270
# Loop
for i, (name, df) in enumerate(data.groupby('features')):

    # Plot at most N features (same limit as the [:N] slices used
    # in the shap.summary_plot loops).
    if i >= N:
        break

    # Color of each marker (feature value : color)
    values = df.feature_values
    palette, norm = scalar_palette(values=values, cmap='coolwarm',
        vmin=values.min(), vmax=values.max())

    print(df)

    # Display
    fig, ax = plt.subplots()
    ax = sns.stripplot(x='timestep',
                       y='shap_values',
                       hue='feature_values',
                       palette=palette,
                       data=df,
                       ax=ax)

    # Colormap object for the colorbar. plt.get_cmap is used because
    # matplotlib.cm.get_cmap is not available in recent matplotlib
    # versions.
    cmap = plt.get_cmap('coolwarm')

    # Configure axes
    plt.title(name)
    # Replace the legend (one entry per hue value) with an empty one;
    # the colorbar conveys the feature values instead.
    plt.legend([], [], frameon=False)
    ax.invert_xaxis()
    add_colorbar(plt.gcf(), cmap, norm)

# Show
plt.show()
  • f0
  • f1
  • f2
  • f3

Out:

     sample  timestep features  shap_values  feature_values
0         0         0       f0           51              25
4         0         1       f0           28              91
8         0         2       f0           26              54
12        0         3       f0           80              93
16        0         4       f0           25              34
20        0         5       f0           58              89
24        1         0       f0           68              52
28        1         1       f0           94              24
32        1         2       f0           12              11
36        1         3       f0           84              36
40        1         4       f0           23              65
44        1         5       f0           65              40
48        2         0       f0           40              23
52        2         1       f0           44               7
56        2         2       f0           55              61
60        2         3       f0           91              96
64        2         4       f0            6              77
68        2         5       f0           90              18
72        3         0       f0           32              55
76        3         1       f0           86              78
80        3         2       f0           49              75
84        3         3       f0           91              48
88        3         4       f0           19              69
92        3         5       f0           82               1
96        4         0       f0           56              13
100       4         1       f0           50              93
104       4         2       f0           76              62
108       4         3       f0           56               8
112       4         4       f0           86              86
116       4         5       f0           86               5
120       5         0       f0           26              59
124       5         1       f0           97              49
128       5         2       f0           13              99
132       5         3       f0            9              14
136       5         4       f0           10              60
140       5         5       f0           96              85
144       6         0       f0           31              81
148       6         1       f0           38              89
152       6         2       f0           29              63
156       6         3       f0           87              32
160       6         4       f0           68              51
164       6         5       f0           56              68
168       7         0       f0           86              66
172       7         1       f0           67              20
176       7         2       f0           20              33
180       7         3       f0           53              61
184       7         4       f0           82              55
188       7         5       f0           16              97
192       8         0       f0           63              38
196       8         1       f0           63              47
200       8         2       f0           94              99
204       8         3       f0           89              50
208       8         4       f0           62              79
212       8         5       f0            1              41
216       9         0       f0            1              16
220       9         1       f0           87              71
224       9         2       f0           83              31
228       9         3       f0           45              86
232       9         4       f0           24              38
236       9         5       f0           57               7
     sample  timestep features  shap_values  feature_values
1         0         0       f1           39              44
5         0         1       f1           30              50
9         0         2       f1           60              73
13        0         3       f1           42              17
17        0         4       f1           40              60
21        0         5       f1           19              12
25        1         0       f1           92               9
29        1         1       f1           60              47
33        1         2       f1           13              53
37        1         3       f1           15              60
41        1         4       f1           34              48
45        1         5       f1           93              79
49        2         0       f1           38               6
53        2         1       f1           37              29
57        2         2       f1           58              27
61        2         3       f1           96              64
65        2         4       f1           77              95
69        2         5       f1           37              11
73        3         0       f1           39              36
77        3         1       f1           54              28
81        3         2       f1           86              41
85        3         3       f1           62              18
89        3         4       f1           43              33
93        3         5       f1           36              92
97        4         0       f1           44              63
101       4         1       f1           70              82
105       4         2       f1           78              48
109       4         3       f1           33              95
113       4         4       f1           71              49
117       4         5       f1           41               5
121       5         0       f1           81              46
125       5         1       f1           43              40
129       5         2       f1           50              45
133       5         3       f1           79              21
137       5         4       f1           49              56
141       5         5       f1           45              73
145       6         0       f1           44              96
149       6         1       f1            7              19
153       6         2       f1           44              49
157       6         3       f1           14              51
161       6         4       f1           82              25
165       6         5       f1           28              48
169       7         0       f1           89              51
173       7         1       f1           33              59
177       7         2       f1           24               2
181       7         3       f1           79              39
185       7         4       f1           65              42
189       7         5       f1           60              78
193       8         0       f1           81              44
197       8         1       f1           90              69
201       8         2       f1           88              79
205       8         3       f1           65              41
209       8         4       f1           96              71
213       8         5       f1           24              29
217       9         0       f1           78              12
221       9         1       f1           54               4
225       9         2       f1           81               7
229       9         3       f1           13              53
233       9         4       f1           81              87
237       9         5       f1           86              29
     sample  timestep features  shap_values  feature_values
2         0         0       f2           73              58
6         0         1       f2           32              29
10        0         2       f2            6              56
14        0         3       f2           59              45
18        0         4       f2           64              40
22        0         5       f2           84              55
26        1         0       f2           75              80
30        1         1       f2           29               5
34        1         2       f2           58              29
38        1         3       f2           80              57
42        1         4       f2           77              83
46        1         5       f2           60              65
50        2         0       f2           89              24
54        2         1       f2            5              91
58        2         2       f2           58              62
62        2         3       f2           16              43
66        2         4       f2           59              35
70        2         5       f2           59              19
74        3         0       f2           55               0
78        3         1       f2           55              49
82        3         2       f2           33              63
86        3         3       f2           94              52
90        3         4       f2           24              78
94        3         5       f2           93              29
98        4         0       f2           46              24
102       4         1       f2           98              50
106       4         2       f2           12              60
110       4         3       f2           44              35
114       4         4       f2           34              74
118       4         5       f2           40              37
122       5         0       f2           88              75
126       5         1       f2           65               3
130       5         2       f2           68              96
134       5         3       f2           14              59
138       5         4       f2           82              71
142       5         5       f2           75              83
146       6         0       f2           27              96
150       6         1       f2           53              51
154       6         2       f2           98              72
158       6         3       f2           34              17
162       6         4       f2           14              62
166       6         5       f2           96              41
170       7         0       f2           22              10
174       7         1       f2           64               8
178       7         2       f2           59              88
182       7         3       f2           37              95
186       7         4       f2           51              82
190       7         5       f2           15              19
194       8         0       f2           29               5
198       8         1       f2           41              68
202       8         2       f2           52              48
206       8         3       f2            1              31
210       8         4       f2           38              84
214       8         5       f2           54              69
218       9         0       f2           89              93
222       9         1       f2           94              78
226       9         2       f2           18              58
230       9         3       f2           48              74
234       9         4       f2           24              97
238       9         5       f2           91              89
     sample  timestep features  shap_values  feature_values
3         0         0       f3           41              52
7         0         1       f3           91              32
11        0         2       f3           25              30
15        0         3       f3           29              44
19        0         4       f3            6              58
23        0         5       f3           91              82
27        1         0       f3           12              58
31        1         1       f3           92              28
35        1         2       f3           18              25
39        1         3       f3           55               2
43        1         4       f3           23              71
47        1         5       f3           20              90
51        2         0       f3           45              95
55        2         1       f3           25              35
59        2         2       f3           30              75
63        2         3       f3           42              37
67        2         4       f3           24              89
71        2         5       f3           23              72
75        3         0       f3           93              49
79        3         1       f3           54              10
83        3         2       f3           58              36
87        3         3       f3           39              10
91        3         4       f3           10              25
95        3         5       f3           33              67
99        4         0       f3           85              64
103       4         1       f3            7              40
107       4         2       f3           49              27
111       4         3       f3           60              37
115       4         4       f3           86              68
119       4         5       f3           71              41
123       5         0       f3           15              60
127       5         1       f3           44               0
131       5         2       f3           53              53
135       5         3       f3           56              69
139       5         4       f3           45              94
143       5         5       f3           19              15
147       6         0       f3           70              41
151       6         1       f3           73              42
155       6         2       f3           91              87
159       6         3       f3           20              92
163       6         4       f3           40              54
167       6         5       f3           90              11
171       7         0       f3           92              77
175       7         1       f3           84              19
179       7         2       f3           76              14
183       7         3       f3           69              69
187       7         4       f3           97              99
191       7         5       f3           13              56
195       8         0       f3           12              41
199       8         1       f3           51              56
203       8         2       f3           85              14
207       8         3       f3           59              93
211       8         4       f3           74              62
215       8         5       f3           32               0
219       9         0       f3           32              81
223       9         1       f3           43              56
227       9         2       f3           57              31
231       9         3       f3            6              46
235       9         4       f3           87              31
239       9         5       f3           40              98

Display using sns.swarmplot

Let’s display the shap values of each feature across all timesteps (one figure per feature).

# Loop
for i, (name, df) in enumerate(data.groupby('features')):

    # Plot at most N features (same limit as the [:N] slices used
    # in the shap.summary_plot loops).
    if i >= N:
        break

    # Color of each marker (feature value : color)
    values = df.feature_values
    palette, norm = scalar_palette(values=values, cmap='coolwarm',
        vmin=values.min(), vmax=values.max())

    # Display
    fig, ax = plt.subplots()
    ax = sns.swarmplot(x='timestep',
                       y='shap_values',
                       hue='feature_values',
                       palette=palette,
                       data=df,
                       size=2,
                       ax=ax)

    # Colormap object for the colorbar. plt.get_cmap is used because
    # matplotlib.cm.get_cmap is not available in recent matplotlib
    # versions.
    cmap = plt.get_cmap('coolwarm')

    # Configure axes
    plt.title(name)
    # Replace the legend (one entry per hue value) with an empty one;
    # the colorbar conveys the feature values instead.
    plt.legend([], [], frameon=False)
    ax.invert_xaxis()
    add_colorbar(plt.gcf(), cmap, norm)

# Show
plt.show()
351
352
353
354
355
356
357
358 """
359 sns.set_theme(style="ticks")
360
361 # Create a dataset with many short random walks
362 rs = np.random.RandomState(4)
363 pos = rs.randint(-1, 2, (20, 5)).cumsum(axis=1)
364 pos -= pos[:, 0, np.newaxis]
365 step = np.tile(range(5), 20)
366 walk = np.repeat(range(20), 5)
367 df = pd.DataFrame(np.c_[pos.flat, step, walk],
368                   columns=["position", "step", "walk"])
369 # Initialize a grid of plots with an Axes for each walk
370 #grid = sns.FacetGrid(df_stack, col="walk", hue="f", palette="tab20c",
371 #                     col_wrap=4, height=1.5)
372
373 grid = sns.FacetGrid(df_stack, hue="f",
374     palette="tab20c", height=1.5)
375
376 # Draw a horizontal line to show the starting point
377 grid.refline(y=0, linestyle=":")
378
379 # Draw a line plot to show the trajectory of each random walk
380 grid.map(plt.plot, "t", "value", marker="o")
381
382 # Adjust the tick positions and labels
383 grid.set(xticks=np.arange(5), yticks=[-3, 3],
384          xlim=(-.5, 4.5), ylim=(-3.5, 3.5))
385
386 # Adjust the arrangement of the plots
387 grid.fig.tight_layout(w_pad=1)
388
389 """
390
391
392 #plt.show()
  • f0
  • f1
  • f2
  • f3

Out:

'\nsns.set_theme(style="ticks")\n\n# Create a dataset with many short random walks\nrs = np.random.RandomState(4)\npos = rs.randint(-1, 2, (20, 5)).cumsum(axis=1)\npos -= pos[:, 0, np.newaxis]\nstep = np.tile(range(5), 20)\nwalk = np.repeat(range(20), 5)\ndf = pd.DataFrame(np.c_[pos.flat, step, walk],\n                  columns=["position", "step", "walk"])\n# Initialize a grid of plots with an Axes for each walk\n#grid = sns.FacetGrid(df_stack, col="walk", hue="f", palette="tab20c",\n#                     col_wrap=4, height=1.5)\n\ngrid = sns.FacetGrid(df_stack, hue="f",\n    palette="tab20c", height=1.5)\n\n# Draw a horizontal line to show the starting point\ngrid.refline(y=0, linestyle=":")\n\n# Draw a line plot to show the trajectory of each random walk\ngrid.map(plt.plot, "t", "value", marker="o")\n\n# Adjust the tick positions and labels\ngrid.set(xticks=np.arange(5), yticks=[-3, 3],\n         xlim=(-.5, 4.5), ylim=(-3.5, 3.5))\n\n# Adjust the arrangement of the plots\ngrid.fig.tight_layout(w_pad=1)\n\n'

Display using sns.FacetGrid

399 #g = sns.FacetGrid(df_stack, col="f", hue='original')
400 #g.map(sns.swarmplot, "t", "value", alpha=.7)
401 #g.add_legend()

Display using shap.beeswarm

409 # REF: https://github.com/slundberg/shap/blob/master/shap/plots/_beeswarm.py
410 #
411 # .. note: It needs a kernel explainer, and while it works with
412 #          common kernels (plot_main07.py) it does not work with
413 #          the DeepKernel for some reason (mask related).

Total running time of the script: ( 0 minutes 4.630 seconds)

Gallery generated by Sphinx-Gallery