.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "_examples/shap/plot_main03.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note Click :ref:`here <sphx_glr_download__examples_shap_plot_main03.py>` to download the full example code .. rst-class:: sphx-glr-example-title .. _sphx_glr__examples_shap_plot_main03.py: 03. Basic example ================= .. GENERATED FROM PYTHON SOURCE LINES 6-188 .. image-sg:: /_examples/shap/images/sphx_glr_plot_main03_001.png :alt: plot main03 :srcset: /_examples/shap/images/sphx_glr_plot_main03_001.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out Out: .. code-block:: none Kernel type: [array([-0.08, -0.12, -0.18]), array([0.08, 0.12, 0.18])] '\nprint(explainer.shap_values(X_train))\n\n# Summary plot\nplot_summary = shap.summary_plot( explainer.shap_values(X_train),\n X_train, cmap=\'viridis\',\n show=False)\n\nplt.tight_layout()\nplt.show()\n\nprint(plot_summary)\n\n\nimport seaborn as sns\nsv = explainer.shap_values(X_train)\nsv = pd.DataFrame(sv, columns=X.columns)\nsv = sv.stack().reset_index()\nsv[\'val\'] = X_train.stack().reset_index()[0]\n\n#import plotly.express as px\n\n#f = px.strip(data_frame=sv, x=0, y=\'level_1\', color=\'val\')\n#f.show()\n\nprint(sv)\n#sns.swarmplot(data=sv, x=0, y=\'level_1\', color=\'viridis\', palette=\'viridis\')\n#sns.stripplot(data=sv, x=0, y=\'level_1\', color=\'viridis\', palette=\'viridis\')\n#plt.show()\nimport sys\nsys.exit()\n#sns.swarmplot(x=)\n\nimport sys\nsys.exit()\n\n#html = f"{shap.getjs()}"\n# Bee swarm\n# .. note: unexpected algorithm matplotlib!\n# .. note: does not return an object!\nplot_bee = shap.plots.beeswarm(shap_values, show=False)\n\n# Sow\nprint("\nBEE")\nprint(plot_bee)\n\n#print(f)\n# Waterfall\n# .. note: not working!\n#shap.plots.waterfall(shap_values[0], max_display=14)\n\n# Force plot\n# .. 
note: not working!\nplot_force = shap.plots.force(explainer.expected_value,\n explainer.shap_values(X_train), X_train,\n matplotlib=False, show=False)\n\n# Show\nprint("\nFORCE:")\nprint(plot_force)\nprint(plot_force.html())\nprint(shap.save_html(\'e.html\', plot_force))\n' | .. code-block:: default :lineno-start: 6 # Generic import numpy as np import pandas as pd import matplotlib.pyplot as plt # Sklearn from sklearn.model_selection import train_test_split from sklearn.datasets import load_iris from sklearn.datasets import load_breast_cancer from sklearn.naive_bayes import GaussianNB from sklearn.linear_model import LogisticRegression from sklearn.tree import DecisionTreeClassifier from sklearn.ensemble import RandomForestClassifier # Xgboost from xgboost import XGBClassifier # ---------------------------------------- # Load data # ---------------------------------------- # Seed seed = 0 # Load dataset bunch = load_iris() bunch = load_breast_cancer() features = list(bunch['feature_names']) # Create DataFrame data = pd.DataFrame(data=np.c_[bunch['data'], bunch['target']], columns=features + ['target']) # Create X, y X = data[bunch['feature_names']] y = data['target'] # Filter X = X.iloc[:500, :3] y = y.iloc[:500] # Split dataset X_train, X_test, y_train, y_test = \ train_test_split(X, y, random_state=seed) # ---------------------------------------- # Classifiers # ---------------------------------------- # Train classifier gnb = GaussianNB() llr = LogisticRegression() dtc = DecisionTreeClassifier(random_state=seed) rfc = RandomForestClassifier(random_state=seed) xgb = XGBClassifier( min_child_weight=0.005, eta= 0.05, gamma= 0.2, max_depth= 4, n_estimators= 100) # Select one clf = rfc # Fit clf.fit(X_train, y_train) # ---------------------------------------- # Find shap values # ---------------------------------------- # Import import shap # Initialise shap.initjs() # Get generic explainer explainer = shap.Explainer(clf, X_train) # Show kernel type print("\nKernel 
type: %s" % type(explainer)) # Variables #rows = X_train.iloc[5, :] #shap = explainer.shap_values(row) # Get shap values shap_values = explainer.shap_values(X_train.iloc[5, :]) #shap_values = explainer.shap_values(X_test.iloc[5, :]) print(shap_values) # Force plot # .. note: not working! plot_force = shap.plots.force(explainer.expected_value[1], shap_values[1], X_train.iloc[5, :], matplotlib=True, show=True) plt.tight_layout() plt.show() """ import sys sys.exit() # ---------------------------------------- # Visualize # ---------------------------------------- # Initialise shap.initjs() """ """ # Dependence plot shap.dependence_plot(0, shap_values, X_train, interaction_index=None, dot_size=5, alpha=0.5, color='#3F75BC', show=False) plt.tight_layout() """ """ print(explainer.shap_values(X_train)) # Summary plot plot_summary = shap.summary_plot( \ explainer.shap_values(X_train), X_train, cmap='viridis', show=False) plt.tight_layout() plt.show() print(plot_summary) import seaborn as sns sv = explainer.shap_values(X_train) sv = pd.DataFrame(sv, columns=X.columns) sv = sv.stack().reset_index() sv['val'] = X_train.stack().reset_index()[0] #import plotly.express as px #f = px.strip(data_frame=sv, x=0, y='level_1', color='val') #f.show() print(sv) #sns.swarmplot(data=sv, x=0, y='level_1', color='viridis', palette='viridis') #sns.stripplot(data=sv, x=0, y='level_1', color='viridis', palette='viridis') #plt.show() import sys sys.exit() #sns.swarmplot(x=) import sys sys.exit() #html = f"{shap.getjs()}" # Bee swarm # .. note: unexpected algorithm matplotlib! # .. note: does not return an object! plot_bee = shap.plots.beeswarm(shap_values, show=False) # Sow print("\nBEE") print(plot_bee) #print(f) # Waterfall # .. note: not working! #shap.plots.waterfall(shap_values[0], max_display=14) # Force plot # .. note: not working! 
plot_force = shap.plots.force(explainer.expected_value, explainer.shap_values(X_train), X_train, matplotlib=False, show=False) # Show print("\nFORCE:") print(plot_force) print(plot_force.html()) print(shap.save_html('e.html', plot_force)) """ .. rst-class:: sphx-glr-timing **Total running time of the script:** ( 0 minutes 0.422 seconds) .. _sphx_glr_download__examples_shap_plot_main03.py: .. only :: html .. container:: sphx-glr-footer :class: sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_main03.py <plot_main03.py>` .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_main03.ipynb <plot_main03.ipynb>` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_