.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "_examples/shap/plot_main01.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note Click :ref:`here ` to download the full example code .. rst-class:: sphx-glr-example-title .. _sphx_glr__examples_shap_plot_main01.py: 01. Basic example ================== .. GENERATED FROM PYTHON SOURCE LINES 8-200 .. image-sg:: /_examples/shap/images/sphx_glr_plot_main01_001.png :alt: plot main01 :srcset: /_examples/shap/images/sphx_glr_plot_main01_001.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out Out: .. code-block:: none Kernel type: .values = array([[ 0.01, 1.37, -4.46], [-0.04, -0.05, -4.71], [-0.04, -0.67, -4.25], ..., [ 0.35, 0.36, 2.19], [-0.03, -0.06, -4.72], [-0.04, -0.67, -4.25]]) .base_values = array([0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 
0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35]) .data = array([[ 17.99, 10.38, 122.8 ], [ 20.57, 17.77, 132.9 ], [ 19.69, 21.25, 130. 
], ..., [ 12.47, 17.31, 80.45], [ 18.49, 17.52, 121.3 ], [ 20.59, 21.24, 137.8 ]]) shap_values (shape): (500, 3) [[-0.09 -0.58 -4.2 ] [-0.57 0.93 3.57] [ 0.33 2.81 -2.97] ... [ 0.78 -0.04 -2.36] [ 0.25 -0.04 0.63] [ 0.22 1.57 -3.67]] None '\nprint(sv)\n#sns.swarmplot(data=sv, x=0, y=\'level_1\', color=\'viridis\', palette=\'viridis\')\n#sns.stripplot(data=sv, x=0, y=\'level_1\', color=\'viridis\', palette=\'viridis\')\n#plt.show()\nimport sys\nsys.exit()\n#sns.swarmplot(x=)\n\nimport sys\nsys.exit()\n\n#html = f"{shap.getjs()}"\n# Bee swarm\n# .. note: unexpected algorithm matplotlib!\n# .. note: does not return an object!\nplot_bee = shap.plots.beeswarm(shap_values, show=False)\n\n# Sow\nprint("\nBEE")\nprint(plot_bee)\n\n#print(f)\n# Waterfall\n# .. note: not working!\n#shap.plots.waterfall(shap_values[0], max_display=14)\n\n# Force plot\n# .. note: not working!\nplot_force = shap.plots.force(explainer.expected_value,\n explainer.shap_values(X_train), X_train,\n matplotlib=False, show=False)\n\n# Show\nprint("\nFORCE:")\nprint(plot_force)\nprint(plot_force.html())\nprint(shap.save_html(\'e.html\', plot_force))\n' | .. 
code-block:: default :lineno-start: 8 # Generic import numpy as np import pandas as pd import matplotlib.pyplot as plt # Sklearn from sklearn.model_selection import train_test_split from sklearn.datasets import load_iris from sklearn.datasets import load_breast_cancer from sklearn.naive_bayes import GaussianNB from sklearn.linear_model import LogisticRegression from sklearn.tree import DecisionTreeClassifier from sklearn.ensemble import RandomForestClassifier # Xgboost from xgboost import XGBClassifier # ---------------------------------------- # Load data # ---------------------------------------- # Seed seed = 0 # Load dataset bunch = load_iris() bunch = load_breast_cancer() features = list(bunch['feature_names']) # Create DataFrame data = pd.DataFrame(data=np.c_[bunch['data'], bunch['target']], columns=features + ['target']) # Create X, y X = data[bunch['feature_names']] y = data['target'] # Filter X = X.iloc[:500, :3] y = y.iloc[:500] # Split dataset X_train, X_test, y_train, y_test = \ train_test_split(X, y, random_state=seed) # ---------------------------------------- # Classifiers # ---------------------------------------- # Train classifier gnb = GaussianNB() llr = LogisticRegression() dtc = DecisionTreeClassifier(random_state=seed) rfc = RandomForestClassifier(random_state=seed) xgb = XGBClassifier( min_child_weight=0.005, eta= 0.05, gamma= 0.2, max_depth= 4, n_estimators= 100) # Select one clf = xgb # Fit clf.fit(X_train, y_train) # ---------------------------------------- # Find shap values # ---------------------------------------- # Import import shap """ # Create shap explainer if isinstance(clf, (DecisionTreeClassifier, RandomForestClassifier, XGBClassifier)): # Set Tree explainer explainer = shap.TreeExplainer(clf) elif isinstance(clf, int): # Set NN explainer explainer = shap.DeepExplainer(clf) else: # Set generic kernel explainer explainer = shap.KernelExplainer(clf.predict_proba, X_train) """ # Get generic explainer explainer = 
shap.Explainer(clf, X_train) # Show kernel type print("\nKernel type: %s" % type(explainer)) # Get shap values shap_values = explainer(X) print(shap_values) # For interactions!! # https://github.com/slundberg/shap/issues/501 # Get shap values #shap_values = \ # explainer.shap_values(X_train) #shap_interaction_values = \ # explainer.shap_interaction_values(X_train) # Show information print("shap_values (shape): %s" % \ str(shap_values.shape)) #print("shap_values_interaction (shape): %s" % \ # str(shap_interaction_values.shape)) # ---------------------------------------- # Visualize # ---------------------------------------- # Initialise shap.initjs() """ # Dependence plot shap.dependence_plot(0, shap_values, X_train, interaction_index=None, dot_size=5, alpha=0.5, color='#3F75BC', show=False) plt.tight_layout() """ print(explainer.shap_values(X_train)) # Summary plot plot_summary = shap.summary_plot( \ explainer.shap_values(X_train), X_train, cmap='viridis', show=False) plt.tight_layout() plt.show() print(plot_summary) import seaborn as sns sv = explainer.shap_values(X_train) sv = pd.DataFrame(sv, columns=X.columns) sv = sv.stack().reset_index() sv['val'] = X_train.stack().reset_index()[0] #import plotly.express as px #f = px.strip(data_frame=sv, x=0, y='level_1', color='val') #f.show() """ print(sv) #sns.swarmplot(data=sv, x=0, y='level_1', color='viridis', palette='viridis') #sns.stripplot(data=sv, x=0, y='level_1', color='viridis', palette='viridis') #plt.show() import sys sys.exit() #sns.swarmplot(x=) import sys sys.exit() #html = f"{shap.getjs()}" # Bee swarm # .. note: unexpected algorithm matplotlib! # .. note: does not return an object! plot_bee = shap.plots.beeswarm(shap_values, show=False) # Sow print("\nBEE") print(plot_bee) #print(f) # Waterfall # .. note: not working! #shap.plots.waterfall(shap_values[0], max_display=14) # Force plot # .. note: not working! 
plot_force = shap.plots.force(explainer.expected_value, explainer.shap_values(X_train), X_train, matplotlib=False, show=False) # Show print("\nFORCE:") print(plot_force) print(plot_force.html()) print(shap.save_html('e.html', plot_force)) """ .. rst-class:: sphx-glr-timing **Total running time of the script:** ( 0 minutes 2.068 seconds) .. _sphx_glr_download__examples_shap_plot_main01.py: .. only :: html .. container:: sphx-glr-footer :class: sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_main01.py <plot_main01.py>` .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_main01.ipynb <plot_main01.ipynb>` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_