Note
Click here to download the full example code
05. Swarmplot
6 # Libraries
7 import shap
8 import numpy as np
9 import pandas as pd
10 import seaborn as sns
11
12 import matplotlib.pyplot as plt
13 import matplotlib as mpl
14 import matplotlib.colorbar
15 import matplotlib.colors
16 import matplotlib.cm
17
18 from mpl_toolkits.axes_grid1 import make_axes_locatable
19
# Detect whether the example is executed as a script. ``__file__`` is set
# when a module is run from a file and is usually undefined in interactive
# sessions, where evaluating the bare name raises NameError. Only that
# exception is caught so unrelated errors are not silently swallowed.
try:
    __file__
    TERMINAL = True
except NameError:
    TERMINAL = False
25
26
27 # ------------------------
28 # Methods
29 # ------------------------
def scalar_colormap(values, cmap, vmin, vmax):
    """Create a seaborn palette with one colour per value.

    Parameters
    ----------
    values : array-like
        The values to create the corresponding colors for.

    cmap : str or matplotlib.colors.Colormap
        The colormap (anything accepted by ``ScalarMappable``).

    vmin, vmax : float
        The minimum and maximum possible values. Values outside this
        range are clipped to the ends of the colormap.

    Returns
    -------
    colormap : seaborn palette
        The palette built by ``sns.color_palette`` from one colour per
        entry of ``values``, in the same order.

    norm : matplotlib.colors.Normalize
        The (clipping) normalisation used to map values to colours.
    """
    # clip=True maps out-of-range values onto the vmin/vmax colours.
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
    mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    colormap = sns.color_palette([mapper.to_rgba(v) for v in values])
    return colormap, norm
55
56
def scalar_palette(values, cmap, vmin, vmax):
    """Map each value to a colour, for use as a seaborn ``palette`` dict.

    Parameters
    ----------
    values : array-like
        The values to create the corresponding colors for.

    cmap : str
        Name of the colormap (resolved with ``sns.color_palette``).

    vmin, vmax : float
        The minimum and maximum possible values. ``norm`` does not clip,
        so values outside this range get the colormap's under/over colours.

    Returns
    -------
    colors : dict
        ``{value: RGBA tuple}`` for every distinct entry of ``values``.

    norm : matplotlib.colors.Normalize
        The normalisation used to map values into the colormap.
    """
    # Resolve the colormap name into a callable matplotlib colormap.
    cmap = sns.color_palette(cmap, as_cmap=True)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    # Keys must match the hue values so seaborn can look colours up.
    colors = {cval: cmap(norm(cval)) for cval in values}
    return colors, norm
87
def load_shap_file(path='../../datasets/shap/'):
    """Load the SHAP table and add a relative ``timestep`` column.

    Parameters
    ----------
    path : str or pathlib.Path, optional
        Directory containing ``shap.csv``. The default is the location used
        by this example, relative to the current working directory.

    Returns
    -------
    pandas.DataFrame
        The contents of ``shap.csv`` without its first column. The stored
        ``timestep`` column is renamed to ``indice`` and a new ``timestep``
        column holds ``indice - max(indice)``.

    Notes
    -----
    .. note:: The stored ``timestep`` column is not a time step but a
              matrix index. Since the time steps start at ``t=-T`` and
              end at ``t=0``, the index is shifted so that the largest
              index maps to ``t=0``. Shifting by the maximum (rather than
              by the number of distinct indices) keeps this true even if
              the indices are not contiguous or do not start at 0.
    """
    from pathlib import Path
    data = pd.read_csv(Path(path) / 'shap.csv')
    # Drop the first column (presumably the index written by to_csv).
    data = data.iloc[:, 1:]
    data = data.rename(columns={'timestep': 'indice'})
    data['timestep'] = data['indice'] - data['indice'].max()
    return data
105
106
107
# ------------------------------------------------------
# Main
# ------------------------------------------------------
# Configuration
cmap_name = 'coolwarm'   # colormap used to colour the feature values
norm_shap = True         # when True, the plots below share global y-limits

# Load the SHAP table, then keep only samples with an index below 100 so
# that the (slow) swarmplots stay affordable.
data = load_shap_file()
keep = data['sample'] < 100
data = data.loc[keep]

# Echo the filtered table when running as a script.
if TERMINAL:
    print("\nShow:")
    print(data)
Let’s see what the data looks like:
# Display the first ten rows of the filtered table.
data.iloc[:10]
Display using sns.swarmplot
Warning
This method can be quite slow, which is why the data is restricted to samples with an index below 100.
Note
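The y-axis has been ‘normalized’: every figure uses the same y-limits (the global minimum and maximum SHAP value), so the features can be compared directly.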
def add_colorbar(fig, cmap, norm, ax=None):
    """Attach a vertical colorbar to the right of an axes.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure that owns the axes.

    cmap : matplotlib.colors.Colormap
        Colormap drawn in the colorbar. A Colormap instance (rather than
        a name) is expected for compatibility with old matplotlib versions.

    norm : matplotlib.colors.Normalize
        Normalisation mapping data values onto the colormap.

    ax : matplotlib.axes.Axes, optional
        Axes to attach the colorbar to. Defaults to ``fig.gca()``.

    Returns
    -------
    matplotlib.colorbar.ColorbarBase
        The created colorbar.
    """
    if ax is None:
        ax = fig.gca()
    divider = make_axes_locatable(ax)
    # Carve a new axes (5% of the parent width, small gap) on the right;
    # append_axes also adds it to the figure.
    ax_cb = divider.append_axes("right", size="5%", pad=0.05)
    return matplotlib.colorbar.ColorbarBase(
        ax_cb, cmap=cmap, norm=norm, orientation='vertical')
146
147
# Colormap object for the colorbars. ``matplotlib.cm.get_cmap`` was
# deprecated in Matplotlib 3.7 and removed in 3.9; ``plt.get_cmap`` works
# on both old and new versions.
colorbar_cmap = plt.get_cmap(cmap_name)

# Global SHAP range, used as shared y-limits when ``norm_shap`` is True so
# all features are drawn on the same scale.
shap_ylim = (data.shap_values.min(), data.shap_values.max())

# Number of features (one figure each) to display; swarmplots are slow.
n_show = 7

# Loop
for i, (name, df) in enumerate(data.groupby('features')):
    if i >= n_show:
        break

    # Colour each feature value, normalised to this feature's own range.
    values = df.feature_values
    palette, norm = scalar_palette(values=values,
                                   cmap=cmap_name,
                                   vmin=values.min(),
                                   vmax=values.max())

    # Display
    fig, ax = plt.subplots()
    sns.swarmplot(x='timestep',
                  y='shap_values',
                  hue='feature_values',
                  palette=palette,
                  data=df,
                  size=2,
                  ax=ax)

    # Format figure
    ax.set_title(name)
    # Hide the hue legend (it would list individual feature values); the
    # colorbar added below shows the same value-to-colour mapping.
    ax.legend([], [], frameon=False)

    if norm_shap:
        ax.set_ylim(*shap_ylim)

    # To invert the x axis (if there are no negative timesteps), use:
    # ax.invert_xaxis()

    # Add colorbar
    add_colorbar(fig, colorbar_cmap, norm)

# Show
plt.show()
Total running time of the script: ( 0 minutes 13.999 seconds)