Note
Click here to download the full example code
05. Sepsis shap by timestep (boxplot)
This script visualizes SHAP (SHapley Additive exPlanations) values to analyze feature importance across different time periods. It loads a pre-computed shap.csv file, containing SHAP values obtained for sepsis patients with an LSTM predictive model, and uses Plotly to generate a comprehensive boxplot.
The resulting plot displays the distribution of SHAP values on the y-axis for each feature at every timestep on the x-axis. This makes it easy to quickly identify which features have the most significant impact on the model’s predictions (those with SHAP values of higher magnitude) and to see how that influence changes over time.
18 # Libraries
19 import pandas as pd
20 import numpy as np
21 import matplotlib as mpl
22 import plotly.express as px
23
24 from plotly.io import show
25 from plotly.colors import n_colors
26 from plotly.express.colors import sample_colorscale
27
# Matplotlib defaults (small fonts, compact legend handles, legend
# anchored at the upper left), applied in a single update.
# See https://matplotlib.org/devdocs/users/explain/customizing.html
mpl.rcParams.update({
    'axes.titlesize': 8,
    'axes.labelsize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'legend.fontsize': 7,
    'legend.handlelength': 1,
    'legend.handleheight': 1,
    'legend.loc': 'upper left',
})
37
# Detect whether the code is being executed from a file. ``__file__``
# is not defined in interactive sessions, so looking it up raises
# NameError there. Only that exception is caught, so unrelated errors
# (e.g. KeyboardInterrupt) are never silently swallowed.
try:
    __file__
except NameError:
    TERMINAL = False
else:
    TERMINAL = True
43
44
# Read the pre-computed SHAP values. The path is relative to the
# current working directory.
data = pd.read_csv(filepath_or_buffer='../../datasets/shap/shap.csv')

# When executed from a file, print the table to the console.
if TERMINAL:
    print("\nData:", data, sep="\n")

# Bare expression: has no effect in a plain script; presumably kept so
# that notebook-style front-ends render the DataFrame.
data
53
# Number of colors: one per distinct value in the 'features' column.
N = data.features.nunique()

# see https://plotly.com/python/builtin-colorscales/#discrete-color-sequences
# see https://plotly.github.io/plotly.py-docs/generated/plotly.express.box.html

# Generate an array of rainbow colors by fixing the saturation and
# lightness of the HSL representation of colour and marching around the
# hue. Plotly accepts any CSS color format, see e.g.
# http://www.w3schools.com/cssref/css_colors_legal.asp.
#
# Hue is an angle, so 0 and 360 degrees denote the same colour. The
# endpoint is therefore excluded; otherwise the first and last entries
# of the palette would be indistinguishable.
c0 = [f'hsl({h},50%,50%)'
      for h in np.linspace(0, 360, N, endpoint=False)]
65
# Alternative palettes: sample N evenly spaced points in [0, 1] from
# several built-in Plotly colorscales.
x = np.linspace(0, 1, N)
c1, c2, c3, c4 = (
    sample_colorscale(scale_name, list(x))
    for scale_name in ('viridis', 'RdBu', 'Jet', 'Agsunset')
)
72
# .. note:: Remove width and height if running locally.

# Boxplot of the SHAP value distribution at each timestep, with one
# box (and colour, taken from the c4 palette) per feature. Only the
# outlier points are drawn individually.
fig = px.box(
    data_frame=data,
    x='timestep',
    y='shap_values',
    color='features',
    color_discrete_sequence=c4,
    points='outliers',
    width=750,
    height=900,
)
79
# .. note:: If using widescreen, commenting the legend section
#           will automatically generate a vertical legend with
#           scrolling if needed. For display purposes in the
#           docs we have included the legend on top.

# Horizontal legend placed just above the plotting area and aligned
# to its right edge. (A custom legend font or figure margins could be
# configured here as well; neither is set.)
legend_on_top = {
    'orientation': "h",
    'entrywidth': 140,
    'yanchor': "bottom",
    'y': 1.02,
    'xanchor': "right",
    'x': 1,
}

# Fully transparent colour used for both backgrounds.
transparent = 'rgba(0,0,0,0)'

# Update layout
fig.update_layout(
    legend=legend_on_top,
    paper_bgcolor=transparent,
    plot_bgcolor=transparent,
)
110
# Both axes share the same look: outside ticks, no axis line, no
# mirrored axis and a light grey grid.
axis_style = {
    'mirror': False,
    'ticks': 'outside',
    'showline': False,
    'linecolor': 'black',
    'gridcolor': 'lightgrey',
}

# Update xaxis
fig.update_xaxes(**axis_style)

# Update yaxis
fig.update_yaxes(**axis_style)

# Show
show(fig)
Total running time of the script: ( 0 minutes 3.995 seconds)