05. Stripplot

Warning

This method is quite slow.

  9 # Libraries
 10 import shap
 11 import numpy as np
 12 import pandas as pd
 13 import seaborn as sns
 14
 15 import matplotlib.pyplot as plt
 16 import matplotlib as mpl
 17 import matplotlib.colorbar
 18 import matplotlib.colors
 19 import matplotlib.cm
 20
 21 from mpl_toolkits.axes_grid1 import make_axes_locatable
 22
 23 try:
 24     __file__
 25     TERMINAL = True
 26 except:
 27     TERMINAL = False
 28
 29
 30 # ------------------------
 31 # Methods
 32 # ------------------------
 33 def scalar_colormap(values, cmap, vmin, vmax):
 34     """This method creates a colormap based on values.
 35
 36     Parameters
 37     ----------
 38     values : array-like
 39     The values to create the corresponding colors
 40
 41     cmap : str
 42     The colormap
 43
 44     vmin, vmax : float
 45     The minimum and maximum possible values
 46
 47     Returns
 48     -------
 49     scalar colormap
 50     """
 51     # Create scalar mappable
 52     norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
 53     mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
 54     # Get color map
 55     colormap = sns.color_palette([mapper.to_rgba(i) for i in values])
 56     # Return
 57     return colormap, norm
 58
 59
 60 def scalar_palette(values, cmap, vmin, vmax):
 61     """This method creates a colorpalette based on values.
 62
 63     Parameters
 64     ----------
 65     values : array-like
 66     The values to create the corresponding colors
 67
 68     cmap : str
 69     The colormap
 70
 71     vmin, vmax : float
 72     The minimum and maximum possible values
 73
 74     Returns
 75     -------
 76     scalar colormap
 77
 78     """
 79     # Create a matplotlib colormap from name
 80     # cmap = sns.light_palette(cmap, reverse=False, as_cmap=True)
 81     cmap = sns.color_palette(cmap, as_cmap=True)
 82     # Normalize to the range of possible values from df["c"]
 83     norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
 84     # Create a color dictionary (value in c : color from colormap)
 85     colors = {}
 86     for cval in values:
 87         colors.update({cval: cmap(norm(cval))})
 88     # Return
 89     return colors, norm
 90
 91
 92 def load_shap_file():
 93     """Load shap file.
 94
 95     .. note: The timestep does not indicate time step but matrix
 96              index index. Since the matrix index for time steps
 97              started in negative t=-T and ended in t=0 the
 98              transformation should be taken into account.
 99
100     """
101     from pathlib import Path
102     # Load data
103     path = Path('../../datasets/shap/')
104     data = pd.read_csv(path / 'shap.csv')
105     data = data.iloc[:, 1:]
106     data = data.rename(columns={'timestep': 'indice'})
107     data['timestep'] = data.indice - (data.indice.nunique() - 1)
108     return data
109
110
111
112 # -------------------------------------------------------------------
113 #                              Main
114 # -------------------------------------------------------------------
# Configuration
cmap_name = 'coolwarm'  # name of the colormap used for the feature values
norm_shap = True        # share the y-limits (global SHAP min/max) across plots

# Load data
data = load_shap_file()
# Uncomment to keep only samples with id < 100 (faster plotting):
# data = data[data['sample'] < 100]

# Show (only when executed from a terminal)
if TERMINAL:
    print("\nShow:", data, sep="\n")

Let’s see what the data looks like.

# Display the first ten rows of the table
data.head(n=10)
sample indice features feature_values shap_values timestep
0 0 0 Ward Lactate 0.0 0.000652 -6
1 0 0 Ward Glucose 0.0 -0.000596 -6
2 0 0 Ward sO2 0.0 0.000231 -6
3 0 0 White blood cell count, blood 0.0 0.000582 -6
4 0 0 Platelets 0.0 -0.001705 -6
5 0 0 Haemoglobin 0.0 -0.000918 -6
6 0 0 Mean cell volume, blood 0.0 -0.000654 -6
7 0 0 Haematocrit 0.0 -0.000487 -6
8 0 0 Mean cell haemoglobin conc, blood 0.0 0.000090 -6
9 0 0 Mean cell haemoglobin level, blood 0.0 -0.000296 -6


Let’s plot the data using sns.stripplot.

Warning

This method seems to be quite slow.

Note

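The y-axis has been ‘normalized’: when norm_shap is True, every plot uses the same y-limits (the global minimum and maximum SHAP value), so the plots can be compared directly.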

def add_colorbar(fig, cmap, norm, ax=None):
    """Attach a vertical colorbar to the right of an axes.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives the new colorbar axes.

    cmap : matplotlib.colors.Colormap
        Colormap displayed by the colorbar.

    norm : matplotlib.colors.Normalize
        Normalisation mapping data values onto the colormap.

    ax : matplotlib.axes.Axes, optional
        Axes next to which the colorbar is placed. Defaults to the
        current axes (``plt.gca()``), which was the previous behaviour.

    Returns
    -------
    matplotlib.colorbar.ColorbarBase
        The created colorbar (previously discarded), so callers can, for
        example, set its label.
    """
    if ax is None:
        ax = plt.gca()
    divider = make_axes_locatable(ax)
    # Narrow axes (5% of the parent's width) appended on the right side.
    ax_cb = divider.new_horizontal(size="5%", pad=0.05)
    fig.add_axes(ax_cb)
    return matplotlib.colorbar.ColorbarBase(
        ax_cb, cmap=cmap, norm=norm, orientation='vertical')
148
149
# Number of features (one figure each) to display.
n_plots = 7

# Colormap for the colorbars. ``plt.get_cmap`` works on both old and new
# matplotlib versions, whereas ``matplotlib.cm.get_cmap`` was removed in
# matplotlib 3.9. It does not depend on the feature, so fetch it once.
cbar_cmap = plt.get_cmap(cmap_name)

# Global SHAP range, used to share y-limits across plots (norm_shap).
shap_ylim = (data.shap_values.min(), data.shap_values.max())

# Loop
for i, (name, df) in enumerate(data.groupby('features')):

    # Show only the first ``n_plots`` features
    if i >= n_plots:
        break

    # Get palette: one colour per value, scaled to this feature's range
    values = df.feature_values
    palette, norm = scalar_palette(values=values,
        cmap=cmap_name, vmin=values.min(),
        vmax=values.max())

    # Display
    fig, ax = plt.subplots()
    ax = sns.stripplot(x='timestep',
                       y='shap_values',
                       hue='feature_values',
                       palette=palette,
                       data=df,
                       ax=ax)

    # Format figure
    ax.set_title(name)
    # Hide the hue legend (one entry per distinct value); the colorbar
    # added below conveys the same information.
    ax.legend([], [], frameon=False)

    if norm_shap:
        ax.set_ylim(*shap_ylim)

    # Invert x axis (if no negative timesteps)
    # ax.invert_xaxis()

    # Add colorbar
    add_colorbar(fig, cbar_cmap, norm)

# Show
plt.show()
  • Alanine Transaminase
  • Albumin
  • Alkaline Phosphatase
  • Bilirubin
  • C-Reactive Protein
  • Chloride
  • Creatinine

Total running time of the script: ( 0 minutes 20.942 seconds)

Gallery generated by Sphinx-Gallery