05. Swarmplot

  6 # Libraries
  7 import shap
  8 import numpy as np
  9 import pandas as pd
 10 import seaborn as sns
 11
 12 import matplotlib.pyplot as plt
 13 import matplotlib as mpl
 14 import matplotlib.colorbar
 15 import matplotlib.colors
 16 import matplotlib.cm
 17
 18 from mpl_toolkits.axes_grid1 import make_axes_locatable
 19
 20 try:
 21     __file__
 22     TERMINAL = True
 23 except:
 24     TERMINAL = False
 25
 26
 27 # ------------------------
 28 # Methods
 29 # ------------------------
 30 def scalar_colormap(values, cmap, vmin, vmax):
 31     """This method creates a colormap based on values.
 32
 33     Parameters
 34     ----------
 35     values : array-like
 36     The values to create the corresponding colors
 37
 38     cmap : str
 39     The colormap
 40
 41     vmin, vmax : float
 42     The minimum and maximum possible values
 43
 44     Returns
 45     -------
 46     scalar colormap
 47     """
 48     # Create scalar mappable
 49     norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
 50     mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
 51     # Get color map
 52     colormap = sns.color_palette([mapper.to_rgba(i) for i in values])
 53     # Return
 54     return colormap, norm
 55
 56
 57 def scalar_palette(values, cmap, vmin, vmax):
 58     """This method creates a colorpalette based on values.
 59
 60     Parameters
 61     ----------
 62     values : array-like
 63     The values to create the corresponding colors
 64
 65     cmap : str
 66     The colormap
 67
 68     vmin, vmax : float
 69     The minimum and maximum possible values
 70
 71     Returns
 72     -------
 73     scalar colormap
 74
 75     """
 76     # Create a matplotlib colormap from name
 77     # cmap = sns.light_palette(cmap, reverse=False, as_cmap=True)
 78     cmap = sns.color_palette(cmap, as_cmap=True)
 79     # Normalize to the range of possible values from df["c"]
 80     norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
 81     # Create a color dictionary (value in c : color from colormap)
 82     colors = {}
 83     for cval in values:
 84         colors.update({cval: cmap(norm(cval))})
 85     # Return
 86     return colors, norm
 87
 88 def load_shap_file():
 89     """Load shap file.
 90
 91     .. note: The timestep does not indicate time step but matrix
 92              index index. Since the matrix index for time steps
 93              started in negative t=-T and ended in t=0 the
 94              transformation should be taken into account.
 95
 96     """
 97     from pathlib import Path
 98     # Load data
 99     path = Path('../../datasets/shap/')
100     data = pd.read_csv(path / 'shap.csv')
101     data = data.iloc[:, 1:]
102     data = data.rename(columns={'timestep': 'indice'})
103     data['timestep'] = data.indice - (data.indice.nunique() - 1)
104     return data
105
106
107
108 # ------------------------------------------------------
109 #                       Main
110 # ------------------------------------------------------
# Configuration
cmap_name = 'coolwarm'  # name of the colormap used for the feature values
norm_shap = True        # use the same SHAP (y-axis) limits in every figure

# Read the SHAP table from disk.
data = load_shap_file()

# Keep only samples with id < 100: the swarmplot is slow on large data.
data = data.loc[data['sample'].lt(100)]

# Print the table when running as a script.
if TERMINAL:
    print("\nShow:")
    print(data)

Let’s see what the data looks like:

# Preview the first ten rows of the table.
data.iloc[:10]
sample indice features feature_values shap_values timestep
0 0 0 Ward Lactate 0.0 0.000652 -6
1 0 0 Ward Glucose 0.0 -0.000596 -6
2 0 0 Ward sO2 0.0 0.000231 -6
3 0 0 White blood cell count, blood 0.0 0.000582 -6
4 0 0 Platelets 0.0 -0.001705 -6
5 0 0 Haemoglobin 0.0 -0.000918 -6
6 0 0 Mean cell volume, blood 0.0 -0.000654 -6
7 0 0 Haematocrit 0.0 -0.000487 -6
8 0 0 Mean cell haemoglobin conc, blood 0.0 0.000090 -6
9 0 0 Mean cell haemoglobin level, blood 0.0 -0.000296 -6


Display the SHAP values of each feature over time using sns.swarmplot.

Warning

This method is quite slow on large datasets, which is why only samples with an id below 100 are plotted.

Note

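The y-axis has been ‘normalized’: every figure uses the same limits (the minimum and maximum SHAP values of the plotted data), so the figures can be compared directly.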

def add_colorbar(fig, cmap, norm, ax=None):
    """Add a vertical colorbar to the right of an axes.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives the new colorbar axes.
    cmap : str or matplotlib.colors.Colormap
        Colormap displayed by the colorbar.
    norm : matplotlib.colors.Normalize
        Normaliser defining the value range shown on the colorbar.
    ax : matplotlib.axes.Axes, optional
        Axes next to which the colorbar is placed. Defaults to the current
        axes of ``fig``.

    Returns
    -------
    matplotlib.colorbar.ColorbarBase
        The created colorbar (e.g. to set a label on it).
    """
    if ax is None:
        # Use the axes of the figure that was passed in, not pyplot's
        # current figure, so the colorbar always ends up in ``fig``.
        ax = fig.gca()
    divider = make_axes_locatable(ax)
    # Narrow axes (5% of ``ax``) on the right, separated by a small pad.
    ax_cb = divider.new_horizontal(size="5%", pad=0.05)
    fig.add_axes(ax_cb)
    return matplotlib.colorbar.ColorbarBase(
        ax_cb, cmap=cmap, norm=norm, orientation='vertical')
146
147
# Number of features (one figure each) to display.
max_figures = 7

# Loop: one swarmplot per feature.
for i, (name, df) in enumerate(data.groupby('features')):

    # Colour each point by its feature value, scaled to this feature's range.
    values = df.feature_values
    palette, norm = scalar_palette(values=values,
        cmap=cmap_name, vmin=values.min(),
        vmax=values.max())

    # Display
    fig, ax = plt.subplots()
    ax = sns.swarmplot(x='timestep',
                       y='shap_values',
                       hue='feature_values',
                       palette=palette,
                       data=df,
                       size=2,
                       ax=ax)

    # Format figure
    ax.set_title(name)
    # Hide the hue legend (one entry per distinct value); the colorbar
    # added below replaces it.
    ax.legend([], [], frameon=False)

    if norm_shap:
        # Same y-limits for every feature so the figures are comparable.
        ax.set_ylim(data.shap_values.min(),
                    data.shap_values.max())

    # Invert x axis (only needed if timesteps are not negative)
    # ax.invert_xaxis()

    # The colorbar needs a Colormap object, not the palette dict above.
    # pyplot.get_cmap works on old and new matplotlib versions, whereas
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9.
    cmap = plt.get_cmap(cmap_name)

    # Add colorbar
    add_colorbar(fig, cmap, norm)

    # Show only the first ``max_figures`` features.
    if i + 1 >= max_figures:
        break

# Show
plt.show()
  • Alanine Transaminase
  • Albumin
  • Alkaline Phosphatase
  • Bilirubin
  • C-Reactive Protein
  • Chloride
  • Creatinine

Total running time of the script: ( 0 minutes 13.999 seconds)

Gallery generated by Sphinx-Gallery