.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_examples\shap\plot_main05_stripplot.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__examples_shap_plot_main05_stripplot.py:


05a. Custom using stripplot
=====================================

This script demonstrates how to create a custom visualization for
sequential (time-series) SHAP values using ``seaborn.stripplot``. This
approach provides a granular, per-feature view of how SHAP values are
distributed across timesteps, offering an alternative to the standard
SHAP library plots.

The script's workflow focuses on:

- **Loading Pre-computed Data:** It loads a tidy DataFrame of SHAP and
  feature values, structured for time-series analysis.
- **Per-Feature Visualization:** It iterates through each feature,
  generating a dedicated stripplot that isolates the feature's impact
  over time from that of the other features.
- **Advanced Coloring:** Each data point is colored by its original
  feature value, and a color bar is added, replicating the context
  provided by native SHAP plots.
- **Plot Customization:** It shows how to control axes, legends, and
  other plot aesthetics for a polished final visualization.

Although this approach is slower than the alternatives, it is well
suited to detailed, publication-quality plots that reveal how feature
contributions evolve over a sequence.

.. GENERATED FROM PYTHON SOURCE LINES 29-149

.. code-block:: default
    :lineno-start: 30

    # Libraries
    import shap
    import numpy as np
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    import matplotlib.colorbar
    import matplotlib.colors
    import matplotlib.cm

    from mpl_toolkits.axes_grid1 import make_axes_locatable

    try:
        __file__
        TERMINAL = True
    except NameError:
        TERMINAL = False


    # ------------------------
    # Methods
    # ------------------------
    def scalar_colormap(values, cmap, vmin, vmax):
        """This method creates a colormap based on values.

        Parameters
        ----------
        values : array-like
            The values to create the corresponding colors
        cmap : str
            The colormap
        vmin, vmax : float
            The minimum and maximum possible values

        Returns
        -------
        colormap : seaborn palette
            One color per value.
        norm : matplotlib.colors.Normalize
            The normalization used to map values to colors.
        """
        # Create scalar mappable
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
        mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
        # Get color map
        colormap = sns.color_palette([mapper.to_rgba(i) for i in values])
        # Return
        return colormap, norm


    def scalar_palette(values, cmap, vmin, vmax):
        """This method creates a color palette based on values.

        Parameters
        ----------
        values : array-like
            The values to create the corresponding colors
        cmap : str
            The colormap
        vmin, vmax : float
            The minimum and maximum possible values

        Returns
        -------
        colors : dict
            Maps each value to an RGBA color.
        norm : matplotlib.colors.Normalize
            The normalization used to map values to colors.
        """
        # Create a matplotlib colormap from name
        # cmap = sns.light_palette(cmap, reverse=False, as_cmap=True)
        cmap = sns.color_palette(cmap, as_cmap=True)
        # Normalize to the range of possible values
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        # Create a color dictionary (value : color from colormap)
        colors = {}
        for cval in values:
            colors[cval] = cmap(norm(cval))
        # Return
        return colors, norm


    def load_shap_file():
        """Load shap file.

        .. note: The 'timestep' column in the file holds the matrix
            index rather than the time step. Since the matrix indices
            for time steps run from t=-T (index 0) to t=0 (index T),
            the index is shifted back to the time step.
        """
        from pathlib import Path
        # Load data
        path = Path('../../datasets/shap/')
        data = pd.read_csv(path / 'shap.csv')
        data = data.iloc[:, 1:]
        data = data.rename(columns={'timestep': 'indice'})
        data['timestep'] = data.indice - (data.indice.nunique() - 1)
        return data


    # -------------------------------------------------------------------
    # Main
    # -------------------------------------------------------------------
    # Configuration
    cmap_name = 'coolwarm'  # colormap name
    norm_shap = True

    # Load data
    data = load_shap_file()
    # data = data[data['sample'] < 100]

    # Show
    if TERMINAL:
        print("\nShow:")
        print(data)


.. GENERATED FROM PYTHON SOURCE LINES 150-151

Let's see what the data looks like.

.. GENERATED FROM PYTHON SOURCE LINES 151-154

.. code-block:: default
    :lineno-start: 151

    data.head(10)


.. code-block:: none

       sample  indice  features                            feature_values  shap_values  timestep
    0       0       0  Ward Lactate                                   0.0     0.000652        -6
    1       0       0  Ward Glucose                                   0.0    -0.000596        -6
    2       0       0  Ward sO2                                       0.0     0.000231        -6
    3       0       0  White blood cell count, blood                  0.0     0.000582        -6
    4       0       0  Platelets                                      0.0    -0.001705        -6
    5       0       0  Haemoglobin                                    0.0    -0.000918        -6
    6       0       0  Mean cell volume, blood                        0.0    -0.000654        -6
    7       0       0  Haematocrit                                    0.0    -0.000487        -6
    8       0       0  Mean cell haemoglobin conc, blood              0.0     0.000090        -6
    9       0       0  Mean cell haemoglobin level, blood             0.0    -0.000296        -6

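Two preparation steps used by this example, shifting the matrix index ``indice`` to a ``timestep`` and mapping each feature value to a color for the ``hue`` palette, can be checked on a small frame. This is a minimal sketch under stated assumptions: the frame and its values are hypothetical (not taken from ``shap.csv``), and ``matplotlib.colormaps`` assumes matplotlib >= 3.5. It mirrors the logic of ``load_shap_file`` and ``scalar_palette`` rather than reproducing the script itself.

.. code-block:: python

    import numpy as np
    import pandas as pd
    import matplotlib as mpl
    import matplotlib.colors

    # Hypothetical miniature tidy frame: one sample, one feature and
    # seven matrix indices (0..6), mimicking the layout of shap.csv.
    df = pd.DataFrame({
        'sample': 0,
        'features': 'Albumin',
        'indice': np.arange(7),
        'feature_values': [30.0, 32.0, 35.0, 35.0, 38.0, 40.0, 41.0],
        'shap_values': np.linspace(-0.01, 0.01, 7),
    })

    # Shift matrix indices 0..T to timesteps -T..0 (as in load_shap_file).
    df['timestep'] = df.indice - (df.indice.nunique() - 1)

    # Map every feature value to an RGBA color (as in scalar_palette).
    cmap = mpl.colormaps['coolwarm']
    norm = mpl.colors.Normalize(vmin=df.feature_values.min(),
                                vmax=df.feature_values.max())
    palette = {v: cmap(norm(v)) for v in df.feature_values}

    print(df[['indice', 'timestep']].T)
    print(len(palette))

With seven indices, index 0 maps to timestep -6 and index 6 to timestep 0, which is consistent with the first rows of the data preview. Duplicate feature values collapse into a single dictionary key, so the palette holds one color per unique value, which is what ``sns.stripplot`` expects when ``hue`` and ``palette`` are combined.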
.. GENERATED FROM PYTHON SOURCE LINES 155-161

Let's plot each feature using ``sns.stripplot``.

.. warning:: This method can be quite slow.

.. note:: When ``norm_shap`` is ``True``, the y-axis limits of every plot
          are set to the global minimum and maximum SHAP values, so all
          plots share a common scale and can be compared directly.

.. GENERATED FROM PYTHON SOURCE LINES 161-212

.. code-block:: default
    :lineno-start: 162

    def add_colorbar(fig, cmap, norm):
        """Append a vertical color bar to the right of the current axes.

        Parameters
        ----------
        fig : matplotlib.figure.Figure
            The figure containing the current axes.
        cmap : matplotlib.colors.Colormap
            The colormap used to color the points.
        norm : matplotlib.colors.Normalize
            The normalization mapping feature values to colors.
        """
        divider = make_axes_locatable(plt.gca())
        ax_cb = divider.new_horizontal(size="5%", pad=0.05)
        fig.add_axes(ax_cb)
        cb1 = matplotlib.colorbar.ColorbarBase(ax_cb, cmap=cmap, norm=norm,
                                               orientation='vertical')


    # Loop
    for i, (name, df) in enumerate(data.groupby('features')):

        # Get color palette (feature value -> color)
        values = df.feature_values
        palette, norm = scalar_palette(values=values, cmap=cmap_name,
            vmin=values.min(), vmax=values.max())

        # Display
        fig, ax = plt.subplots()
        ax = sns.stripplot(x='timestep',
                           y='shap_values',
                           hue='feature_values',
                           palette=palette,
                           data=df,
                           ax=ax)

        # Format figure
        plt.title(name)
        plt.legend([], [], frameon=False)

        if norm_shap:
            plt.ylim(data.shap_values.min(), data.shap_values.max())

        # Invert x axis (if no negative timesteps)
        # ax.invert_xaxis()

        # Create colormap object (matplotlib.cm.get_cmap is deprecated;
        # matplotlib.colormaps requires matplotlib >= 3.5)
        cmap = matplotlib.colormaps[cmap_name]

        # Add colorbar
        add_colorbar(fig, cmap, norm)

        # Show only the first 7 features (i = 0, ..., 6)
        if int(i) > 5:
            break

    # Show
    plt.show()


.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_stripplot_001.png
         :alt: Alanine Transaminase
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_stripplot_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_stripplot_002.png
         :alt: Albumin
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_stripplot_002.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_stripplot_003.png
         :alt: Alkaline Phosphatase
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_stripplot_003.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_stripplot_004.png
         :alt: Bilirubin
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_stripplot_004.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_stripplot_005.png
         :alt: C-Reactive Protein
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_stripplot_005.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_stripplot_006.png
         :alt: Chloride
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_stripplot_006.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_stripplot_007.png
         :alt: Creatinine
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_stripplot_007.png
         :class: sphx-glr-multi-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 8.549 seconds)


.. _sphx_glr_download__examples_shap_plot_main05_stripplot.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_main05_stripplot.py `

  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_main05_stripplot.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_