.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_examples/plotly/plot_main12_mpl2ply_shap.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__examples_plotly_plot_main12_mpl2ply_shap.py:


12. MPL2PLY SHAP summary
==============================

.. note:: In the latest commit of plotly, line 368 of
   ``packages/python/plotly/plotly/matplotlylib/mpltools.py`` still calls
   the ``is_frame_like()`` function, which is no longer available in recent
   Matplotlib releases. An issue is already open to track this. If you still
   want to use the ``mpl_to_plotly()`` function, you may need to downgrade
   Matplotlib.

.. GENERATED FROM PYTHON SOURCE LINES 10-203



.. image-sg:: /_examples/plotly/images/sphx_glr_plot_main12_mpl2ply_shap_001.png
   :alt: plot main12 mpl2ply shap
   :srcset: /_examples/plotly/images/sphx_glr_plot_main12_mpl2ply_shap_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none


    Kernel type:
    Shap values:
    .values =
    array([[ 0.01222555,  1.37204514, -4.45924262],
           [-0.03614123, -0.05402528, -4.71252587],
           [-0.04250672, -0.67383818, -4.2542946 ],
           ...,
           [ 0.34987633,  0.35588726,  2.19385727],
           [-0.02678569, -0.06291732, -4.71763066],
           [-0.04250672, -0.67383818, -4.2542946 ]])

    .base_values =
    array([0.35041043, 0.35041043, 0.35041043, ..., 0.35041043, 0.35041043,
           0.35041043])

    .data =
    array([[ 17.99,  10.38, 122.8 ],
           [ 20.57,  17.77, 132.9 ],
           [ 19.69,  21.25, 130.  ],
           ...,
           [ 12.47,  17.31,  80.45],
           [ 18.49,  17.52, 121.3 ],
           [ 20.59,  21.24, 137.8 ]])
    (500, 3)
    (500,)
    (500, 3)

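The plotly layout in the listing below hardcodes the y-axis ticks as ``tickvals=[0, 1, 2]`` and ``ticktext=features[:-1]``. Since ``features`` holds all 30 breast-cancer feature names (``'target'`` is only appended when the DataFrame columns are built), ``features[:-1]`` is a list of 29 names rather than the three plotted features, and it does not follow the importance ranking used by the summary plot. A minimal sketch of deriving both lists from the SHAP matrix instead, assuming ``shap.summary_plot`` keeps its default sorting (features ranked by mean, equivalently summed, absolute SHAP value, least important at ``y=0``); ``summary_plot_yticks`` is a hypothetical helper, not part of shap or plotly:

```python
import numpy as np


def summary_plot_yticks(shap_matrix, feature_names, max_display=20):
    """Return (tickvals, ticktext) for the y-axis of a SHAP summary plot.

    Hypothetical helper. Assumes the default summary_plot ordering:
    features ranked by mean |SHAP|, drawn from the least important at
    y=0 up to the most important at the top, keeping at most
    ``max_display`` (>= 1) features.
    """
    shap_matrix = np.asarray(shap_matrix, dtype=float)
    # Global importance of each feature (column): mean |SHAP| over all rows
    importance = np.abs(shap_matrix).mean(axis=0)
    # Ascending order, so the most important feature gets the largest y
    order = np.argsort(importance)[-max_display:]
    tickvals = list(range(len(order)))
    ticktext = [feature_names[i] for i in order]
    return tickvals, ticktext


# Usage with a made-up 2x3 matrix and placeholder names (illustrative only)
demo = np.array([[3.0, 0.1, -1.0],
                 [-2.0, 0.2, 0.5]])
vals, text = summary_plot_yticks(demo, ["f0", "f1", "f2"])
print(vals, text)  # [0, 1, 2] ['f1', 'f2', 'f0']
```

With the objects from this example it could be called as ``summary_plot_yticks(explainer.shap_values(X_train), list(X_train.columns))``, passing the same matrix given to ``summary_plot``; the two returned lists would then replace the hardcoded ``tickvals`` and ``ticktext`` in ``fig.update_layout``. The sketch does not cover ``summary_plot`` options such as ``sort=False``.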

|


.. code-block:: default
    :lineno-start: 10


    # Generic
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt

    # Sklearn
    from sklearn.model_selection import train_test_split
    from sklearn.datasets import load_iris
    from sklearn.datasets import load_breast_cancer
    from sklearn.naive_bayes import GaussianNB
    from sklearn.linear_model import LogisticRegression
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.ensemble import RandomForestClassifier

    # Xgboost
    from xgboost import XGBClassifier

    # Detect whether the script runs as a file (``__file__`` is defined)
    try:
        __file__
        TERMINAL = True
    except NameError:
        TERMINAL = False

    # ----------------------------------------
    # Load data
    # ----------------------------------------
    # Seed
    seed = 0

    # Load dataset (only the breast cancer data is used below)
    bunch = load_iris()
    bunch = load_breast_cancer()

    # Features
    features = list(bunch['feature_names'])

    # Create DataFrame
    data = pd.DataFrame(data=np.c_[bunch['data'], bunch['target']],
                        columns=features + ['target'])

    # Create X, y
    X = data[bunch['feature_names']]
    y = data['target']

    # Filter: keep the first 500 rows and the first 3 features
    X = X.iloc[:500, :3]
    y = y.iloc[:500]

    # Split dataset
    X_train, X_test, y_train, y_test = \
        train_test_split(X, y, random_state=seed)

    # ----------------------------------------
    # Classifiers
    # ----------------------------------------
    # Define some classifiers
    gnb = GaussianNB()
    llr = LogisticRegression()
    dtc = DecisionTreeClassifier(random_state=seed)
    rfc = RandomForestClassifier(random_state=seed)
    xgb = XGBClassifier(
        min_child_weight=0.005, eta=0.05, gamma=0.2,
        max_depth=4, n_estimators=100)

    # Select one
    clf = xgb

    # Fit
    clf.fit(X_train, y_train)

    # ----------------------------------------
    # Compute shap values
    # ----------------------------------------
    # Import
    import shap

    # Get generic explainer
    explainer = shap.Explainer(clf, X_train)

    # Show kernel type
    print("\nKernel type: %s" % type(explainer))

    # Get shap values
    shap_values = explainer(X)

    # Show shap values
    print("Shap values:")
    print(shap_values)
    print(shap_values.values.shape)
    print(shap_values.base_values.shape)
    print(shap_values.data.shape)

    # Draw the matplotlib summary plot (not shown, so it can be converted)
    plot_summary = shap.summary_plot(
        explainer.shap_values(X_train), X_train,
        cmap='viridis', show=False)

    # Show
    #plt.show()

    """
    # Convert to plotly
    import plotly.tools as tls
    import plotly.graph_objs as go

    # Get current figure and convert
    fig = tls.mpl_to_plotly(plt.gcf())

    # Format
    # Update layout
    fig.update_layout(
        #xaxis_title='False Positive Rate',
        #yaxis_title='True Positive Rate',
        #yaxis=dict(scaleanchor="x", scaleratio=1),
        #xaxis=dict(constrain='domain'),
        width=700, height=350,
        #legend=dict(
        #    x=1.0, y=0.0,  # x=1, y=1.02
        #    orientation="v",
        #    font=dict(
        #        size=12,
        #        color='black'),
        #    yanchor="bottom",
        #    xanchor="right",
        #),
        font=dict(
            size=15,
            #family="Times New Roman",
            #color="black",
        ),
        title=dict(
            font=dict(
                # family="Times New Roman",
                # color="black"
            )
        ),
        yaxis=dict(
            tickmode='array',
            tickvals=[0, 1, 2],
            ticktext=features[:-1],
            tickfont=dict(size=15)
        ),
        xaxis=dict(
            tickfont=dict(size=15)),
        #margin={
        #    'l': 0,
        #    'r': 0,
        #    'b': 0,
        #    't': 0,
        #    'pad': 4
        #},
        paper_bgcolor='rgba(0,0,0,0)',  # transparent
        plot_bgcolor='rgba(0,0,0,0)',   # transparent
        template='simple_white'
    )

    # Update scatter
    fig.update_traces(marker={'size': 10})

    # Add vertical line
    fig.add_vline(x=0.0, line_width=2,
                  line_dash="dash", line_color="black")

    # .. note:: Would it be possible to get the values of
    #           cmin, cmax and the tick vals from the shap
    #           values? Ideally we do not want to hardcode
    #           them.

    # Add colorbar
    colorbar_trace = go.Scatter(
        x=[None], y=[None], mode='markers',
        marker=dict(
            colorscale='viridis',
            showscale=True,
            cmin=-5,
            cmax=5,
            colorbar=dict(thickness=20,
                          tickvals=[-5, 5],
                          ticktext=['Low', 'High'],
                          outlinewidth=0)),
        hoverinfo='none'
    )
    fig['layout']['showlegend'] = False
    fig.add_trace(colorbar_trace)

    # Show
    #fig.show()
    fig
    """


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 1.411 seconds)


.. _sphx_glr_download__examples_plotly_plot_main12_mpl2ply_shap.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_main12_mpl2ply_shap.py `



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_main12_mpl2ply_shap.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_