Note
Click here to download the full example code
05. Stripplot
Warning
This method is quite slow.
9 # Libraries
10 import shap
11 import numpy as np
12 import pandas as pd
13 import seaborn as sns
14
15 import matplotlib.pyplot as plt
16 import matplotlib as mpl
17 import matplotlib.colorbar
18 import matplotlib.colors
19 import matplotlib.cm
20
21 from mpl_toolkits.axes_grid1 import make_axes_locatable
22
# Detect whether the code is executed from a file (script / module), in
# which case ``__file__`` is defined in the module globals, or from an
# interactive session where it is not. An explicit membership test avoids
# the previous bare ``except:``, which would also have swallowed unrelated
# errors such as KeyboardInterrupt.
TERMINAL = '__file__' in globals()
28
29
30 # ------------------------
31 # Methods
32 # ------------------------
def scalar_colormap(values, cmap, vmin, vmax):
    """Convert values into a seaborn palette with one color per value.

    Parameters
    ----------
    values : array-like
        The values to convert into colors.

    cmap : str
        The colormap used by the scalar mappable.

    vmin, vmax : float
        The limits of the normalization. Values outside this range
        are clipped (``clip=True``).

    Returns
    -------
    colormap : seaborn color palette
        The RGBA color of each element of ``values``, in order.
    norm : matplotlib.colors.Normalize
        The normalization that was applied to the values.
    """
    # Normalizer and value -> RGBA mapper
    normalizer = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
    scalar_map = mpl.cm.ScalarMappable(norm=normalizer, cmap=cmap)
    # One RGBA tuple per input value
    rgba = [scalar_map.to_rgba(value) for value in values]
    return sns.color_palette(rgba), normalizer
58
59
def scalar_palette(values, cmap, vmin, vmax):
    """Build a ``{value: color}`` palette for the given values.

    Parameters
    ----------
    values : array-like
        The values to convert into colors. Each value becomes a key
        of the returned dictionary.

    cmap : str
        The name of the colormap (resolved with ``sns.color_palette``).

    vmin, vmax : float
        The minimum and maximum possible values, used to normalize
        each value before looking up its color.

    Returns
    -------
    colors : dict
        Maps each value to the color returned by the colormap.
    norm : matplotlib.colors.Normalize
        The normalization that was applied to the values.
    """
    # Colormap object from its name
    color_fn = sns.color_palette(cmap, as_cmap=True)
    # Scale values to the range given by vmin and vmax
    scale = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    # Value -> color lookup
    lookup = {value: color_fn(scale(value)) for value in values}
    return lookup, scale
90
91
def load_shap_file():
    """Load the shap file as a DataFrame.

    The first column of the csv file is discarded and the original
    ``timestep`` column is renamed to ``indice``. A new ``timestep``
    column is then computed from it.

    .. note:: The original timestep does not indicate a time step but
              a matrix index. Since the matrix index for time steps
              started at t=-T and ended at t=0, the index is shifted
              so that the largest index maps to timestep 0.

    Returns
    -------
    pd.DataFrame
        The data with the columns ``indice`` and ``timestep``.
    """
    from pathlib import Path

    # Read the file, drop the first column and keep the raw index
    folder = Path('../../datasets/shap/')
    data = (
        pd.read_csv(folder / 'shap.csv')
        .iloc[:, 1:]
        .rename(columns={'timestep': 'indice'})
    )

    # Shift the indices by the number of distinct indices minus one
    n_indices = data.indice.nunique()
    data['timestep'] = data.indice - (n_indices - 1)
    return data
109
110
111
112 # -------------------------------------------------------------------
113 # Main
114 # -------------------------------------------------------------------
115 # Configuration
norm_shap = True        # use the same y-axis limits in every figure
cmap_name = 'coolwarm'  # colormap name

# Load the shap values
data = load_shap_file()
# To work with a subset of the samples only:
# data = data[data['sample'] < 100]

# Show the data when executed from a file
if TERMINAL:
    print("\nShow:", data, sep="\n")
Let’s see what the data looks like
# Preview the first ten rows
data.head(n=10)
Let’s plot the SHAP values of each feature using sns.stripplot
Warning
This method seems to be quite slow.
Note
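The y-axis has been ‘normalized’: when norm_shap is enabled, every figure uses the same limits (the global minimum and maximum of shap_values).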
def add_colorbar(fig, cmap, norm):
    """Add a vertical colorbar to the right of the figure's current axes.

    A new axes (5% of the width of the current axes) is created next
    to the current axes of ``fig`` and the colorbar is drawn on it.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure to which the colorbar axes is added.

    cmap : matplotlib colormap
        The colormap displayed in the colorbar.

    norm : matplotlib.colors.Normalize
        The normalization that maps data values to the colormap.

    Returns
    -------
    matplotlib.colorbar.ColorbarBase
        The colorbar that was created.
    """
    # Use the axes of the given figure rather than pyplot's global
    # state, so the colorbar is always attached to ``fig``.
    divider = make_axes_locatable(fig.gca())
    ax_cb = divider.new_horizontal(size="5%", pad=0.05)
    fig.add_axes(ax_cb)
    return matplotlib.colorbar.ColorbarBase(
        ax_cb, cmap=cmap, norm=norm, orientation='vertical')
148
149
# Colormap shown in the colorbar. It only depends on the configuration,
# so it is resolved once. ``plt.get_cmap`` is used because
# ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7 and removed
# in matplotlib 3.9.
colorbar_cmap = plt.get_cmap(cmap_name)

# Common y-axis limits (computed over all the features) so that the
# figures are directly comparable.
shap_min = data.shap_values.min()
shap_max = data.shap_values.max()

# Number of features (figures) to display
n_figures = 7

# Loop
for i, (name, df) in enumerate(data.groupby('features')):

    # Show only the first n_figures features
    if i >= n_figures:
        break

    # Get palette (feature value -> color) and its normalization
    values = df.feature_values
    palette, norm = scalar_palette(values=values,
                                   cmap=cmap_name,
                                   vmin=values.min(),
                                   vmax=values.max())

    # Display
    fig, ax = plt.subplots()
    ax = sns.stripplot(x='timestep',
                       y='shap_values',
                       hue='feature_values',
                       palette=palette,
                       data=df,
                       ax=ax)

    # Format figure (the legend is replaced by the colorbar)
    ax.set_title(name)
    ax.legend([], [], frameon=False)

    if norm_shap:
        ax.set_ylim(shap_min, shap_max)

    # Invert x axis (if no negative timesteps)
    # ax.invert_xaxis()

    # Add colorbar
    add_colorbar(fig, colorbar_cmap, norm)

# Show
plt.show()
Total running time of the script: ( 0 minutes 20.942 seconds)