12. MPL2PLY SHAP summary

Note

As of the latest commit of plotly, packages/python/plotly/plotly/matplotlylib/mpltools.py (line 368) still calls the is_frame_like() function, which newer Matplotlib releases have removed, so the conversion fails with those releases. There is already an issue tracking this. If you still want to use the mpl_to_plotly() function, you may need to downgrade Matplotlib.

plot main12 mpl2ply shap

Out:

Kernel type: <class 'shap.explainers._tree.Tree'>
Shap values:
.values =
array([[ 0.01222555,  1.37204514, -4.45924262],
       [-0.03614123, -0.05402528, -4.71252587],
       [-0.04250672, -0.67383818, -4.2542946 ],
       ...,
       [ 0.34987633,  0.35588726,  2.19385727],
       [-0.02678569, -0.06291732, -4.71763066],
       [-0.04250672, -0.67383818, -4.2542946 ]])

.base_values =
array([0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043,
       0.35041043, 0.35041043, 0.35041043, 0.35041043, 0.35041043])

.data =
array([[ 17.99,  10.38, 122.8 ],
       [ 20.57,  17.77, 132.9 ],
       [ 19.69,  21.25, 130.  ],
       ...,
       [ 12.47,  17.31,  80.45],
       [ 18.49,  17.52, 121.3 ],
       [ 20.59,  21.24, 137.8 ]])
(500, 3)
(500,)
(500, 3)

'\n# Convert to plotly\nimport plotly.tools as tls\nimport plotly.graph_objs as go\n\n# Get current figure and convert\nfig = tls.mpl_to_plotly(plt.gcf())\n\n# Format\n# Update layout\nfig.update_layout(\n    #xaxis_title=\'False Positive Rate\',\n    #yaxis_title=\'True Positive Rate\',\n    #yaxis=dict(scaleanchor="x", scaleratio=1),\n    #xaxis=dict(constrain=\'domain\'),\n    width=700, height=350,\n    #legend=dict(\n    #    x=1.0, y=0.0,  # x=1, y=1.02\n    #    orientation="v",\n    #    font=dict(\n    #       size=12,\n    #        color=\'black\'),\n    #    yanchor="bottom",\n    #    xanchor="right",\n    #),\n    font=dict(\n        size=15,\n        #family="Times New Roman",\n        #color="black",\n    ),\n    title=dict(\n        font=dict(\n        #    family="Times New Roman",\n        #    color="black"\n        )\n    ),\n    yaxis=dict(\n        tickmode=\'array\',\n        tickvals=[0, 1, 2],\n        ticktext=features[:-1],\n        tickfont=dict(size=15)\n    ),\n    xaxis=dict(\n        tickfont=dict(size=15)),\n    #margin={\n    #    \'l\': 0,\n    #    \'r\': 0,\n    #    \'b\': 0,\n    #    \'t\': 0,\n    #    \'pad\': 4\n    #},\n    paper_bgcolor=\'rgba(0,0,0,0)\',  # transparent\n    plot_bgcolor=\'rgba(0,0,0,0)\',   # transparent\n    template=\'simple_white\'\n)\n\n# Update scatter\nfig.update_traces(marker={\'size\': 10})\n\n# Add vertical lin\nfig.add_vline(x=0.0, line_width=2,\n    line_dash="dash", line_color="black") # green\n\n# .. note:: Would it be possible to get the values of\n#           cmin, cmax and the tick vals from the shap\n#           values? 
Ideally we do not want to hardcode\n#           them.\n\n# Add colorbar\ncolorbar_trace = go.Scatter(\n    x=[None], y=[None], mode=\'markers\',\n    marker=dict(\n        colorscale=\'viridis\',\n        showscale=True,\n        cmin=-5,\n        cmax=5,\n        colorbar=dict(thickness=20,\n            tickvals=[-5, 5],\n            ticktext=[\'Low\', \'High\'],\n            outlinewidth=0)),\n    hoverinfo=\'none\'\n)\nfig[\'layout\'][\'showlegend\'] = False\nfig.add_trace(colorbar_trace)\n\n# Show\n#fig.show()\nfig\n'

 10 # Generic
 11 import numpy as np
 12 import pandas as pd
 13 import matplotlib.pyplot as plt
 14
 15 # Sklearn
 16 from sklearn.model_selection import train_test_split
 17 from sklearn.datasets import load_iris
 18 from sklearn.datasets import load_breast_cancer
 19 from sklearn.naive_bayes import GaussianNB
 20 from sklearn.linear_model import LogisticRegression
 21 from sklearn.tree import DecisionTreeClassifier
 22 from sklearn.ensemble import RandomForestClassifier
 23
 24 # Xgboost
 25 from xgboost import XGBClassifier
 26
 27 try:
 28     __file__
 29     TERMINAL = True
 30 except:
 31     TERMINAL = False
 32
 33
 34 # ----------------------------------------
 35 # Load data
 36 # ----------------------------------------
 37 # Seed
 38 seed = 0
 39
 40 # Load dataset (the iris bunch is immediately replaced by breast cancer)
 41 bunch = load_iris()
 42 bunch = load_breast_cancer()
 43
 44 # Features
 45 features = list(bunch['feature_names'])
 46
 47 # Create DataFrame
 48 data = pd.DataFrame(data=np.c_[bunch['data'],
 49     bunch['target']], columns=features + ['target'])
 50
 51 # Create X, y
 52 X = data[bunch['feature_names']]
 53 y = data['target']
 54
 55 # Filter
 56 X = X.iloc[:500, :3]
 57 y = y.iloc[:500]
 58
 59 # Split dataset
 60 X_train, X_test, y_train, y_test = \
 61     train_test_split(X, y, random_state=seed)
 62
 63
 64 # ----------------------------------------
 65 # Classifiers
 66 # ----------------------------------------
 67 # Define some classifiers
 68 gnb = GaussianNB()
 69 llr = LogisticRegression()
 70 dtc = DecisionTreeClassifier(random_state=seed)
 71 rfc = RandomForestClassifier(random_state=seed)
 72 xgb = XGBClassifier(
 73     min_child_weight=0.005,
 74     eta= 0.05, gamma= 0.2,
 75     max_depth= 4,
 76     n_estimators= 100)
 77
 78 # Select one
 79 clf = xgb
 80
 81 # Fit
 82 clf.fit(X_train, y_train)
 83
 84 # ----------------------------------------
 85 # Compute shap values
 86 # ----------------------------------------
 87 # Import
 88 import shap
 89
 90 # Get generic explainer
 91 explainer = shap.Explainer(clf, X_train)
 92
 93 # Show kernel type
 94 print("\nKernel type: %s" % type(explainer))
 95
 96 # Get shap values
 97 shap_values = explainer(X)
 98
 99 # Show shap values
100 print("Shap values:")
101 print(shap_values)
102 print(shap_values.values.shape)
103 print(shap_values.base_values.shape)
104 print(shap_values.data.shape)
105
106 # Draw the SHAP summary plot on the current matplotlib figure (not shown yet)
107 plot_summary = shap.summary_plot( \
108     explainer.shap_values(X_train),
109     X_train, cmap='viridis',
110     show=False)
111
112 # Show
113 #plt.show()
114
115 """
116 # Convert to plotly
117 import plotly.tools as tls
118 import plotly.graph_objs as go
119
120 # Get current figure and convert
121 fig = tls.mpl_to_plotly(plt.gcf())
122
123 # Format
124 # Update layout
125 fig.update_layout(
126     #xaxis_title='False Positive Rate',
127     #yaxis_title='True Positive Rate',
128     #yaxis=dict(scaleanchor="x", scaleratio=1),
129     #xaxis=dict(constrain='domain'),
130     width=700, height=350,
131     #legend=dict(
132     #    x=1.0, y=0.0,  # x=1, y=1.02
133     #    orientation="v",
134     #    font=dict(
135     #       size=12,
136     #        color='black'),
137     #    yanchor="bottom",
138     #    xanchor="right",
139     #),
140     font=dict(
141         size=15,
142         #family="Times New Roman",
143         #color="black",
144     ),
145     title=dict(
146         font=dict(
147         #    family="Times New Roman",
148         #    color="black"
149         )
150     ),
151     yaxis=dict(
152         tickmode='array',
153         tickvals=[0, 1, 2],
154         ticktext=features[:-1],
155         tickfont=dict(size=15)
156     ),
157     xaxis=dict(
158         tickfont=dict(size=15)),
159     #margin={
160     #    'l': 0,
161     #    'r': 0,
162     #    'b': 0,
163     #    't': 0,
164     #    'pad': 4
165     #},
166     paper_bgcolor='rgba(0,0,0,0)',  # transparent
167     plot_bgcolor='rgba(0,0,0,0)',   # transparent
168     template='simple_white'
169 )
170
171 # Update scatter
172 fig.update_traces(marker={'size': 10})
173
174 # Add vertical lin
175 fig.add_vline(x=0.0, line_width=2,
176     line_dash="dash", line_color="black") # green
177
178 # .. note:: Would it be possible to get the values of
179 #           cmin, cmax and the tick vals from the shap
180 #           values? Ideally we do not want to hardcode
181 #           them.
182
183 # Add colorbar
184 colorbar_trace = go.Scatter(
185     x=[None], y=[None], mode='markers',
186     marker=dict(
187         colorscale='viridis',
188         showscale=True,
189         cmin=-5,
190         cmax=5,
191         colorbar=dict(thickness=20,
192             tickvals=[-5, 5],
193             ticktext=['Low', 'High'],
194             outlinewidth=0)),
195     hoverinfo='none'
196 )
197 fig['layout']['showlegend'] = False
198 fig.add_trace(colorbar_trace)
199
200 # Show
201 #fig.show()
202 fig
203 """

Total running time of the script: ( 0 minutes 1.411 seconds)

Gallery generated by Sphinx-Gallery