05b. Custom plots using summary_plot

This script demonstrates how to adapt the standard shap.summary_plot to visualize the three-dimensional SHAP values (samples × timesteps × features) generated by sequential models. Since summary_plot expects a two-dimensional (samples × features) matrix, the values are sliced in two complementary ways. This makes it possible to interpret feature importance both at specific points in time and across an entire sequence.

The script first reshapes the data and then applies a two-pronged visualization approach:

  • Data Reshaping: It begins by pivoting a tidy DataFrame into the wide-format matrices for SHAP values and feature values that are required by the plotting function.

  • Per-Timestep Analysis: It then iterates through each timestep, creating a separate summary plot that shows the importance of all features at that single moment in the sequence.

  • Per-Feature Analysis: Finally, it iterates through each feature, generating a summary plot that shows how the importance of that single feature evolves across all timesteps.

This example shows how SHAP’s most common plot can be used to uncover the temporal dynamics of feature contributions in time-series and sequential models.
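To make the two views concrete, here is a minimal NumPy sketch (not part of the original script) that slices a (samples, timesteps, features) array both ways. The array sizes and the random values are illustrative assumptions standing in for real SHAP values.

```python
import numpy as np

# Made-up 3D array standing in for SHAP values: (samples, timesteps, features)
rng = np.random.default_rng(0)
n_samples, n_timesteps, n_features = 4, 7, 3
shap_3d = rng.normal(size=(n_samples, n_timesteps, n_features))

# Per-timestep view: fix one timestep -> (samples, features) matrix,
# i.e. "all features at one moment in the sequence".
per_timestep = [shap_3d[:, t, :] for t in range(n_timesteps)]

# Per-feature view: fix one feature -> (samples, timesteps) matrix,
# where each timestep plays the role of a "feature" in the plot.
per_feature = [shap_3d[:, :, f] for f in range(n_features)]

print(per_timestep[0].shape)  # (4, 3)
print(per_feature[0].shape)   # (4, 7)
```

Both views are plain 2D matrices, which is the input shape summary_plot works with.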

# Libraries
import shap
import pandas as pd

import matplotlib.pyplot as plt


# Detect whether the script runs from a file (``__file__`` is defined)
# or interactively, e.g. in a notebook (``__file__`` is not defined).
try:
    __file__
    TERMINAL = True
except NameError:
    TERMINAL = False


# ------------------------
# Methods
# ------------------------
def load_shap_file():
    """Load shap file.

    .. note:: The ``timestep`` column in the file does not indicate the
              time step but the matrix index. Since the time steps of
              the matrix start at t=-T and end at t=0, this
              transformation must be taken into account.

    """
    from pathlib import Path
    # Load data
    path = Path('../../datasets/shap/')
    data = pd.read_csv(path / 'shap.csv')
    # Drop the first column (presumably the saved row index)
    data = data.iloc[:, 1:]
    # Keep the raw matrix index as 'indice'...
    data = data.rename(columns={'timestep': 'indice'})
    # ...and map it to relative time steps: 0..T -> -T..0
    data['timestep'] = data.indice - (data.indice.nunique() - 1)
    return data


# -----------------------------------------------------
#                       Main
# -----------------------------------------------------
# Load data
# data = create_random_shap(10, 6, 4)
data = load_shap_file()
#data = data[data['sample'] < 100]

# Pivot the tidy data into wide matrices: one row per
# (sample, timestep) pair and one column per feature.
shap_values = pd.pivot_table(data,
                             values='shap_values',
                             index=['sample', 'timestep'],
                             columns=['features'])

feature_values = pd.pivot_table(data,
                                values='feature_values',
                                index=['sample', 'timestep'],
                                columns=['features'])

# Show
if TERMINAL:
    print("\nShow:")
    print(data)
    print(shap_values)
    print(feature_values)
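The pivot step can be checked in isolation on a tiny, made-up tidy table that uses the same column names as the DataFrame returned by load_shap_file (2 samples, 3 timesteps and 2 features; the values are arbitrary). This is a sketch for illustration only, not the real dataset.

```python
import pandas as pd

# Made-up tidy table: one row per (sample, timestep, feature) combination.
rows = []
for sample in range(2):
    for timestep in (-2, -1, 0):
        for k, feature in enumerate(('Albumin', 'Bilirubin')):
            rows.append({'sample': sample,
                         'timestep': timestep,
                         'features': feature,
                         'shap_values': 0.1 * sample + 0.01 * timestep + 0.001 * k})
tidy = pd.DataFrame(rows)

# Same call as in the script. Each cell receives a single value, so the
# default 'mean' aggregation of pivot_table leaves it unchanged.
wide = pd.pivot_table(tidy,
                      values='shap_values',
                      index=['sample', 'timestep'],
                      columns=['features'])

print(wide.shape)  # (6, 2): one row per (sample, timestep), one column per feature
```

Because each (sample, timestep, feature) combination occurs exactly once, pivot_table acts as a pure reshape here. Duplicated combinations would instead be averaged silently.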

Let’s see what data looks like.

data.head(10)

   sample  indice  features                            feature_values  shap_values  timestep
0       0       0  Ward Lactate                                   0.0     0.000652        -6
1       0       0  Ward Glucose                                   0.0    -0.000596        -6
2       0       0  Ward sO2                                       0.0     0.000231        -6
3       0       0  White blood cell count, blood                  0.0     0.000582        -6
4       0       0  Platelets                                      0.0    -0.001705        -6
5       0       0  Haemoglobin                                    0.0    -0.000918        -6
6       0       0  Mean cell volume, blood                        0.0    -0.000654        -6
7       0       0  Haematocrit                                    0.0    -0.000487        -6
8       0       0  Mean cell haemoglobin conc, blood              0.0     0.000090        -6
9       0       0  Mean cell haemoglobin level, blood             0.0    -0.000296        -6
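In the table above, the raw matrix index indice 0 corresponds to timestep -6. The transformation in load_shap_file is timestep = indice - (n - 1), where n is the number of distinct indices. The standard-library sketch below assumes seven indices (0 to 6), matching the timesteps -6 to 0 shown in the next tables. The function name indice_to_timestep is hypothetical.

```python
def indice_to_timestep(indice, n_indices):
    """Hypothetical helper: map a raw matrix index (0..n-1) to a timestep (-(n-1)..0)."""
    return indice - (n_indices - 1)


n = 7  # assumed number of distinct indices (timesteps -6 .. 0)
mapping = {i: indice_to_timestep(i, n) for i in range(n)}
print(mapping)  # {0: -6, 1: -5, 2: -4, 3: -3, 4: -2, 5: -1, 6: 0}
```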


Let’s see what shap_values looks like.

shap_values.iloc[:10, :5]

features         Alanine Transaminase    Albumin  Alkaline Phosphatase  Bilirubin  C-Reactive Protein
sample timestep
0      -6                   -0.001809   0.000411              0.000486   0.000500            0.010186
       -5                   -0.001363   0.000563              0.000803  -0.000133            0.005363
       -4                    0.001180   0.000101              0.000859  -0.001680           -0.016017
       -3                    0.004938  -0.001043              0.000570  -0.003175           -0.044723
       -2                    0.006206  -0.001760              0.000382  -0.003976           -0.062485
       -1                   -0.001391  -0.004886              0.002457   0.010031            0.056280
       0                     0.003583   0.023502              0.000534   0.001672           -0.010238
1      -6                    0.000325  -0.000812             -0.000210  -0.000157            0.000971
       -5                    0.000247  -0.002281             -0.000301  -0.000036           -0.000035
       -4                   -0.000316  -0.000034             -0.000307   0.000464           -0.009348


Let’s see what feature_values looks like.

feature_values.iloc[:10, :5]

features         Alanine Transaminase    Albumin  Alkaline Phosphatase  Bilirubin  C-Reactive Protein
sample timestep
0      -6                    0.000000   0.000000              0.000000   0.000000            0.000000
       -5                    0.000000   0.000000              0.000000   0.000000            0.000000
       -4                    0.000000   0.000000              0.000000   0.000000            0.000000
       -3                    0.000000   0.000000              0.000000   0.000000            0.000000
       -2                    0.000000   0.000000              0.000000   0.000000            0.000000
       -1                    0.000000   0.000000              0.000000   0.000000            0.000000
       0                    -0.982956   0.237113             -0.956016  -0.982152           -0.726284
1      -6                   -0.994370  -0.587629             -0.956533  -0.988451           -0.398008
       -5                   -0.993445  -0.587629             -0.954463  -0.990551           -0.190805
       -4                   -0.994370  -0.628866             -0.963260  -0.990551           -0.307893
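The rows of both matrices are (sample, timestep) pairs, sorted by sample and then by timestep. This means a wide matrix can be folded back into the three-dimensional (samples, timesteps, features) layout with a plain reshape. The sketch below demonstrates this on a synthetic frame with the same index layout. The sizes, the feature names f0..f4 and the values are made up.

```python
import numpy as np
import pandas as pd

# Synthetic wide frame laid out like shap_values:
# MultiIndex (sample, timestep), one column per feature.
samples, timesteps, features = 3, 7, 5
index = pd.MultiIndex.from_product(
    [range(samples), range(-timesteps + 1, 1)],
    names=['sample', 'timestep'])
wide = pd.DataFrame(
    np.arange(samples * timesteps * features, dtype=float)
      .reshape(samples * timesteps, features),
    index=index,
    columns=['f%d' % k for k in range(features)])

# Rows are ordered sample-major, so a plain reshape recovers the 3D array.
cube = wide.to_numpy().reshape(samples, timesteps, features)

# cube[s, 0] is the row of sample s at the earliest timestep (-6).
assert np.array_equal(cube[1, 0], wide.loc[(1, -6), :].to_numpy())
print(cube.shape)  # (3, 7, 5)
```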


Display using shap.summary_plot

#
# The first option is to use the ``shap`` library to plot the results.

# Let's define/extract some useful variables.
N = 10  # maximum number of features to plot in the per-feature loop
TIMESTEPS = len(shap_values.index.unique(level='timestep'))  # number of timesteps
SAMPLES = len(shap_values.index.unique(level='sample'))  # number of samples

# Global SHAP range, used to give every plot the same x-axis limits
shap_min = data.shap_values.min()
shap_max = data.shap_values.max()
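The per-feature loop reshapes each column with reshape(-1, TIMESTEPS). This works correctly only if every sample has exactly one row per timestep, in which case SAMPLES * TIMESTEPS equals the number of rows. The small check below uses a hypothetical helper, is_complete_grid, which is not part of the script, and is tested on two made-up frames.

```python
import pandas as pd


def is_complete_grid(wide, level_a='sample', level_b='timestep'):
    """Hypothetical helper: True if each level_a value has one row per level_b value.

    Unique (a, b) pairs drawn from n_a x n_b labels cover the full grid
    only when there are exactly n_a * n_b of them.
    """
    n_a = wide.index.get_level_values(level_a).nunique()
    n_b = wide.index.get_level_values(level_b).nunique()
    no_duplicates = not wide.index.duplicated().any()
    return no_duplicates and len(wide) == n_a * n_b


# Complete grid: 2 samples x 2 timesteps
full = pd.DataFrame(
    {'Albumin': [0.1, 0.2, 0.3, 0.4]},
    index=pd.MultiIndex.from_product([[0, 1], [-1, 0]],
                                     names=['sample', 'timestep']))

# Drop the third row, i.e. sample 1 loses timestep -1
gappy = full.iloc[[0, 1, 3]]

print(is_complete_grid(full))   # True
print(is_complete_grid(gappy))  # False
```

Running such a check on shap_values before the per-feature loop would catch missing timesteps before they cause misaligned reshapes.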

Now, let’s display the SHAP values of all features at each timestep.

# For each timestep (visualise all features)
steps = shap_values.index.get_level_values('timestep').unique()
for i, step in enumerate(steps):
    # Boolean mask selecting the rows of this timestep
    indice = shap_values.index.get_level_values('timestep') == step

    # Create auxiliary matrices (samples x features)
    shap_aux = shap_values.iloc[indice]
    feat_aux = feature_values.iloc[indice]

    # Display (shared x-limits make the plots comparable)
    plt.figure()
    plt.title("Timestep: %s" % step)
    shap.summary_plot(shap_aux.to_numpy(), feat_aux, show=False)
    plt.xlim(shap_min, shap_max)
[Figures: one summary plot per timestep, titled "Timestep: -6", "Timestep: -5", "Timestep: -4", "Timestep: -3", "Timestep: -2", "Timestep: -1" and "Timestep: 0".]
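The loop above selects one timestep from the MultiIndex using a boolean mask. pandas DataFrame.xs with level='timestep' is an alternative way to take the same cross-section. The sketch below uses a small synthetic frame to check that both approaches produce the same (samples × features) matrix. Using xs here is a swapped-in alternative, not what the script does.

```python
import numpy as np
import pandas as pd

# Synthetic frame with the same (sample, timestep) MultiIndex layout
index = pd.MultiIndex.from_product([range(3), [-2, -1, 0]],
                                   names=['sample', 'timestep'])
wide = pd.DataFrame(np.arange(18, dtype=float).reshape(9, 2),
                    index=index, columns=['Albumin', 'Bilirubin'])

step = -1
# Boolean-mask selection, as in the script (keeps both index levels)
mask = wide.index.get_level_values('timestep') == step
by_mask = wide.iloc[mask]

# Cross-section selection (drops the 'timestep' level by default)
by_xs = wide.xs(step, level='timestep')

print(by_xs.shape)  # (3, 2)
```

Both approaches return the same values. The only difference is whether the timestep level remains in the index.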

Now, let’s display the SHAP values of each feature across all timesteps.

# For each feature (visualise all time-steps)
for i, f in enumerate(shap_values.columns[:N]):
    # Show
    # print('%2d. %s' % (i, f))

    # Create auxiliary matrices (select feature and reshape to
    # samples x timesteps; this relies on the pivot table rows being
    # sorted by sample and then by timestep)
    shap_aux = shap_values.iloc[:, i] \
        .to_numpy().reshape(-1, TIMESTEPS)
    feat_aux = feature_values.iloc[:, i] \
        .to_numpy().reshape(-1, TIMESTEPS)
    feat_aux = pd.DataFrame(feat_aux,
        columns=['timestep %s' % j for j in range(-TIMESTEPS+1, 1)]
        )

    # Show (sort=False keeps the timesteps in order
    # instead of sorting them by importance)
    plt.figure()
    plt.title("Feature: %s" % f)
    shap.summary_plot(shap_aux, feat_aux, sort=False, show=False, plot_type='violin')
    plt.xlim(shap_min, shap_max)
    plt.gca().invert_yaxis()  # reverse the vertical order of the timestep rows

# Show
plt.show()
[Figures: one violin summary plot for each of the first 10 features, titled "Feature: Alanine Transaminase", "Feature: Albumin", "Feature: Alkaline Phosphatase", "Feature: Bilirubin", "Feature: C-Reactive Protein", "Feature: Chloride", "Feature: Creatinine", "Feature: D-Dimer", "Feature: Eosinophils" and "Feature: Ferritin".]
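The per-feature loop builds each (samples × timesteps) matrix with a positional reshape. A label-based alternative is Series.unstack('timestep'), which moves the timestep level into the columns. The sketch below compares the two approaches on a synthetic frame. The frame and the value TIMESTEPS = 3 are assumptions made for illustration.

```python
import numpy as np
import pandas as pd

# Synthetic frame with the same (sample, timestep) MultiIndex layout
index = pd.MultiIndex.from_product([range(3), [-2, -1, 0]],
                                   names=['sample', 'timestep'])
wide = pd.DataFrame(np.arange(18, dtype=float).reshape(9, 2),
                    index=index, columns=['Albumin', 'Bilirubin'])
TIMESTEPS = 3

# Script approach: take one column and reshape positionally
by_reshape = wide.iloc[:, 0].to_numpy().reshape(-1, TIMESTEPS)

# Label-based alternative: one row per sample, one column per timestep
by_unstack = wide.iloc[:, 0].unstack('timestep')

print(by_unstack.columns.tolist())  # [-2, -1, 0]
```

On a complete grid the two approaches produce identical results. If a (sample, timestep) pair were missing, unstack would align values by label and insert NaN, whereas a plain reshape could shift values or raise an error.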


Total running time of the script: ( 0 minutes 8.284 seconds)

Gallery generated by Sphinx-Gallery