Note
Click here to download the full example code
12. MPL2PLY SHAP summary
Note
In the latest commit of plotly, packages/python/plotly/plotly/matplotlylib/mpltools.py (line 368) still calls the is_frame_like() function. There is already an issue tracking this. You may need to downgrade Matplotlib if you still want to use the mpl_to_plotly() function.
Out:
Kernel type: <class 'shap.explainers._tree.Tree'>
Shap values:
.values =
array([[ 0.01222555, 1.37204514, -4.45924262],
[-0.03614123, -0.05402528, -4.71252587],
[-0.04250672, -0.67383818, -4.2542946 ],
...,
[ 0.34987633, 0.35588726, 2.19385727],
[-0.02678569, -0.06291732, -4.71763066],
[-0.04250672, -0.67383818, -4.2542946 ]])
.base_values =
array([0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043])
.data =
array([[ 17.99, 10.38, 122.8 ],
[ 20.57, 17.77, 132.9 ],
[ 19.69, 21.25, 130. ],
...,
[ 12.47, 17.31, 80.45],
[ 18.49, 17.52, 121.3 ],
[ 20.59, 21.24, 137.8 ]])
(500, 3)
(500,)
(500, 3)
'\n# Convert to plotly\nimport plotly.tools as tls\nimport plotly.graph_objs as go\n\n# Get current figure and convert\nfig = tls.mpl_to_plotly(plt.gcf())\n\n# Format\n# Update layout\nfig.update_layout(\n #xaxis_title=\'False Positive Rate\',\n #yaxis_title=\'True Positive Rate\',\n #yaxis=dict(scaleanchor="x", scaleratio=1),\n #xaxis=dict(constrain=\'domain\'),\n width=700, height=350,\n #legend=dict(\n # x=1.0, y=0.0, # x=1, y=1.02\n # orientation="v",\n # font=dict(\n # size=12,\n # color=\'black\'),\n # yanchor="bottom",\n # xanchor="right",\n #),\n font=dict(\n size=15,\n #family="Times New Roman",\n #color="black",\n ),\n title=dict(\n font=dict(\n # family="Times New Roman",\n # color="black"\n )\n ),\n yaxis=dict(\n tickmode=\'array\',\n tickvals=[0, 1, 2],\n ticktext=features[:-1],\n tickfont=dict(size=15)\n ),\n xaxis=dict(\n tickfont=dict(size=15)),\n #margin={\n # \'l\': 0,\n # \'r\': 0,\n # \'b\': 0,\n # \'t\': 0,\n # \'pad\': 4\n #},\n paper_bgcolor=\'rgba(0,0,0,0)\', # transparent\n plot_bgcolor=\'rgba(0,0,0,0)\', # transparent\n template=\'simple_white\'\n)\n\n# Update scatter\nfig.update_traces(marker={\'size\': 10})\n\n# Add vertical lin\nfig.add_vline(x=0.0, line_width=2,\n line_dash="dash", line_color="black") # green\n\n# .. note:: Would it be possible to get the values of\n# cmin, cmax and the tick vals from the shap\n# values? Ideally we do not want to hardcode\n# them.\n\n# Add colorbar\ncolorbar_trace = go.Scatter(\n x=[None], y=[None], mode=\'markers\',\n marker=dict(\n colorscale=\'viridis\',\n showscale=True,\n cmin=-5,\n cmax=5,\n colorbar=dict(thickness=20,\n tickvals=[-5, 5],\n ticktext=[\'Low\', \'High\'],\n outlinewidth=0)),\n hoverinfo=\'none\'\n)\nfig[\'layout\'][\'showlegend\'] = False\nfig.add_trace(colorbar_trace)\n\n# Show\n#fig.show()\nfig\n'
10 # Generic
11 import numpy as np
12 import pandas as pd
13 import matplotlib.pyplot as plt
14
15 # Sklearn
16 from sklearn.model_selection import train_test_split
17 from sklearn.datasets import load_iris
18 from sklearn.datasets import load_breast_cancer
19 from sklearn.naive_bayes import GaussianNB
20 from sklearn.linear_model import LogisticRegression
21 from sklearn.tree import DecisionTreeClassifier
22 from sklearn.ensemble import RandomForestClassifier
23
24 # Xgboost
25 from xgboost import XGBClassifier
26
# Flag whether ``__file__`` is defined for this module. Evaluating the bare
# name raises NameError when it is missing (presumably the case in some
# interactive sessions -- confirm for the environments of interest).
try:
    __file__
except NameError:
    # Catch only the missing-name case: a bare ``except`` would also
    # swallow KeyboardInterrupt/SystemExit and hide unrelated errors.
    TERMINAL = False
else:
    TERMINAL = True
32
33
# ----------------------------------------
# Load data
# ----------------------------------------
# Seed shared by the train/test split and the seeded classifiers
seed = 0

# Load dataset. Only the breast cancer data is used; the iris dataset
# used to be loaded first and was overwritten right away, so that
# redundant call has been dropped.
bunch = load_breast_cancer()

# Features
features = list(bunch['feature_names'])

# Create DataFrame: one column per feature plus a trailing 'target' column
data = pd.DataFrame(
    data=np.c_[bunch['data'], bunch['target']],
    columns=features + ['target'])

# Create X, y
X = data[features]
y = data['target']

# Filter: keep the first 500 rows and the first 3 feature columns
X = X.iloc[:500, :3]
y = y.iloc[:500]

# Split dataset
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=seed)
62
63
# ----------------------------------------
# Classifiers
# ----------------------------------------
# Candidate estimators; only the one bound to ``clf`` below is fitted
gnb = GaussianNB()
llr = LogisticRegression()
dtc = DecisionTreeClassifier(random_state=seed)
rfc = RandomForestClassifier(random_state=seed)
xgb = XGBClassifier(
    n_estimators=100,
    max_depth=4,
    eta=0.05,
    gamma=0.2,
    min_child_weight=0.005,
)

# Estimator used for the rest of the script
clf = xgb

# Train it on the training split
clf.fit(X_train, y_train)
83
# ----------------------------------------
# Compute shap values
# ----------------------------------------
# Import
import shap

# Build a generic explainer from the fitted classifier and the
# training data
explainer = shap.Explainer(clf, X_train)

# Report which explainer class was selected
print("\nKernel type: %s" % type(explainer))

# Explain every row of X (not only the training split)
shap_values = explainer(X)

# Dump the explanation object and the shape of each of its parts
print("Shap values:")
print(shap_values)
print(shap_values.values.shape)
print(shap_values.base_values.shape)
print(shap_values.data.shape)

# Draw the summary plot for the training split with matplotlib.
# show=False presumably leaves the figure open for later retrieval
# (e.g. through plt.gcf()) -- confirm against the shap documentation.
plot_summary = shap.summary_plot(
    explainer.shap_values(X_train),
    X_train,
    cmap='viridis',
    show=False,
)
111
112 # Show
113 #plt.show()
114
115 """
116 # Convert to plotly
117 import plotly.tools as tls
118 import plotly.graph_objs as go
119
120 # Get current figure and convert
121 fig = tls.mpl_to_plotly(plt.gcf())
122
123 # Format
124 # Update layout
125 fig.update_layout(
126 #xaxis_title='False Positive Rate',
127 #yaxis_title='True Positive Rate',
128 #yaxis=dict(scaleanchor="x", scaleratio=1),
129 #xaxis=dict(constrain='domain'),
130 width=700, height=350,
131 #legend=dict(
132 # x=1.0, y=0.0, # x=1, y=1.02
133 # orientation="v",
134 # font=dict(
135 # size=12,
136 # color='black'),
137 # yanchor="bottom",
138 # xanchor="right",
139 #),
140 font=dict(
141 size=15,
142 #family="Times New Roman",
143 #color="black",
144 ),
145 title=dict(
146 font=dict(
147 # family="Times New Roman",
148 # color="black"
149 )
150 ),
151 yaxis=dict(
152 tickmode='array',
153 tickvals=[0, 1, 2],
154 ticktext=features[:-1],
155 tickfont=dict(size=15)
156 ),
157 xaxis=dict(
158 tickfont=dict(size=15)),
159 #margin={
160 # 'l': 0,
161 # 'r': 0,
162 # 'b': 0,
163 # 't': 0,
164 # 'pad': 4
165 #},
166 paper_bgcolor='rgba(0,0,0,0)', # transparent
167 plot_bgcolor='rgba(0,0,0,0)', # transparent
168 template='simple_white'
169 )
170
171 # Update scatter
172 fig.update_traces(marker={'size': 10})
173
174 # Add vertical lin
175 fig.add_vline(x=0.0, line_width=2,
176 line_dash="dash", line_color="black") # green
177
178 # .. note:: Would it be possible to get the values of
179 # cmin, cmax and the tick vals from the shap
180 # values? Ideally we do not want to hardcode
181 # them.
182
183 # Add colorbar
184 colorbar_trace = go.Scatter(
185 x=[None], y=[None], mode='markers',
186 marker=dict(
187 colorscale='viridis',
188 showscale=True,
189 cmin=-5,
190 cmax=5,
191 colorbar=dict(thickness=20,
192 tickvals=[-5, 5],
193 ticktext=['Low', 'High'],
194 outlinewidth=0)),
195 hoverinfo='none'
196 )
197 fig['layout']['showlegend'] = False
198 fig.add_trace(colorbar_trace)
199
200 # Show
201 #fig.show()
202 fig
203 """
Total running time of the script: ( 0 minutes 1.411 seconds)