01. Basic example

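Fit an XGBoost classifier on the first three features of the breast-cancer dataset and explain its predictions with SHAP values and a summary plot.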


Kernel type: <class 'shap.explainers._tree.Tree'>
.values =
array([[ 0.01,  1.37, -4.46],
       [-0.04, -0.05, -4.71],
       [-0.04, -0.67, -4.25],
       [ 0.35,  0.36,  2.19],
       [-0.03, -0.06, -4.72],
       [-0.04, -0.67, -4.25]])

.base_values =
array([0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35,
       0.35, 0.35, 0.35, 0.35, 0.35])

.data =
array([[ 17.99,  10.38, 122.8 ],
       [ 20.57,  17.77, 132.9 ],
       [ 19.69,  21.25, 130.  ],
       [ 12.47,  17.31,  80.45],
       [ 18.49,  17.52, 121.3 ],
       [ 20.59,  21.24, 137.8 ]])
shap_values (shape): (500, 3)
<IPython.core.display.HTML object>
[[-0.09 -0.58 -4.2 ]
 [-0.57  0.93  3.57]
 [ 0.33  2.81 -2.97]
 [ 0.78 -0.04 -2.36]
 [ 0.25 -0.04  0.63]
 [ 0.22  1.57 -3.67]]

'\nprint(sv)\n#sns.swarmplot(data=sv, x=0, y=\'level_1\', color=\'viridis\', palette=\'viridis\')\n#sns.stripplot(data=sv, x=0, y=\'level_1\', color=\'viridis\', palette=\'viridis\')\n#plt.show()\nimport sys\nsys.exit()\n#sns.swarmplot(x=)\n\nimport sys\nsys.exit()\n\n#html = f"<head>{shap.getjs()}</head><body>"\n# Bee swarm\n# .. note: unexpected algorithm matplotlib!\n# .. note: does not return an object!\nplot_bee = shap.plots.beeswarm(shap_values, show=False)\n\n# Sow\nprint("\nBEE")\nprint(plot_bee)\n\n#print(f)\n# Waterfall\n# .. note: not working!\n#shap.plots.waterfall(shap_values[0], max_display=14)\n\n# Force plot\n# .. note: not working!\nplot_force = shap.plots.force(explainer.expected_value,\n    explainer.shap_values(X_train), X_train,\n    matplotlib=False, show=False)\n\n# Show\nprint("\nFORCE:")\nprint(plot_force)\nprint(plot_force.html())\nprint(shap.save_html(\'e.html\', plot_force))\n'

  8 # Generic
  9 import numpy as np
 10 import pandas as pd
 11 import matplotlib.pyplot as plt
 13 # Sklearn
 14 from sklearn.model_selection import train_test_split
 15 from sklearn.datasets import load_iris
 16 from sklearn.datasets import load_breast_cancer
 17 from sklearn.naive_bayes import GaussianNB
 18 from sklearn.linear_model import LogisticRegression
 19 from sklearn.tree import DecisionTreeClassifier
 20 from sklearn.ensemble import RandomForestClassifier
 22 # Xgboost
 23 from xgboost import XGBClassifier
 25 # ----------------------------------------
 26 # Load data
 27 # ----------------------------------------
 28 # Seed
 29 seed = 0
 31 # Load dataset
 32 bunch = load_iris()
 33 bunch = load_breast_cancer()
 34 features = list(bunch['feature_names'])
 36 # Create DataFrame
 37 data = pd.DataFrame(data=np.c_[bunch['data'], bunch['target']],
 38                     columns=features + ['target'])
 40 # Create X, y
 41 X = data[bunch['feature_names']]
 42 y = data['target']
 44 # Filter
 45 X = X.iloc[:500, :3]
 46 y = y.iloc[:500]
 49 # Split dataset
 50 X_train, X_test, y_train, y_test = \
 51     train_test_split(X, y, random_state=seed)
 54 # ----------------------------------------
 55 # Classifiers
 56 # ----------------------------------------
 57 # Train classifier
 58 gnb = GaussianNB()
 59 llr = LogisticRegression()
 60 dtc = DecisionTreeClassifier(random_state=seed)
 61 rfc = RandomForestClassifier(random_state=seed)
 62 xgb = XGBClassifier(
 63     min_child_weight=0.005,
 64     eta= 0.05, gamma= 0.2,
 65     max_depth= 4,
 66     n_estimators= 100)
 68 # Select one
 69 clf = xgb
 71 # Fit
 72 clf.fit(X_train, y_train)
 74 # ----------------------------------------
 75 # Find shap values
 76 # ----------------------------------------
 77 # Import
 78 import shap
 80 """
 81 # Create shap explainer
 82 if isinstance(clf,
 83     (DecisionTreeClassifier,
 84      RandomForestClassifier,
 85      XGBClassifier)):
 86     # Set Tree explainer
 87     explainer = shap.TreeExplainer(clf)
 88 elif isinstance(clf, int):
 89     # Set NN explainer
 90     explainer = shap.DeepExplainer(clf)
 91 else:
 92     # Set generic kernel explainer
 93     explainer = shap.KernelExplainer(clf.predict_proba, X_train)
 94 """
 96 # Get generic explainer
 97 explainer = shap.Explainer(clf, X_train)
 99 # Show kernel type
100 print("\nKernel type: %s" % type(explainer))
102 # Get shap values
103 shap_values = explainer(X)
105 print(shap_values)
107 # For interactions!!
108 # https://github.com/slundberg/shap/issues/501
110 # Get shap values
111 #shap_values = \
112 #    explainer.shap_values(X_train)
113 #shap_interaction_values = \
114 #    explainer.shap_interaction_values(X_train)
116 # Show information
117 print("shap_values (shape): %s" % \
118       str(shap_values.shape))
119 #print("shap_values_interaction (shape): %s" % \
120 #      str(shap_interaction_values.shape))
123 # ----------------------------------------
124 # Visualize
125 # ----------------------------------------
126 # Initialise
127 shap.initjs()
129 """
130 # Dependence plot
131 shap.dependence_plot(0, shap_values,
132     X_train, interaction_index=None, dot_size=5,
133     alpha=0.5, color='#3F75BC', show=False)
134 plt.tight_layout()
135 """
137 print(explainer.shap_values(X_train))
139 # Summary plot
140 plot_summary = shap.summary_plot( \
141     explainer.shap_values(X_train),
142     X_train, cmap='viridis',
143     show=False)
145 plt.tight_layout()
146 plt.show()
148 print(plot_summary)
151 import seaborn as sns
152 sv = explainer.shap_values(X_train)
153 sv = pd.DataFrame(sv, columns=X.columns)
154 sv = sv.stack().reset_index()
155 sv['val'] = X_train.stack().reset_index()[0]
157 #import plotly.express as px
159 #f = px.strip(data_frame=sv, x=0, y='level_1', color='val')
160 #f.show()
162 """
163 print(sv)
164 #sns.swarmplot(data=sv, x=0, y='level_1', color='viridis', palette='viridis')
165 #sns.stripplot(data=sv, x=0, y='level_1', color='viridis', palette='viridis')
166 #plt.show()
167 import sys
168 sys.exit()
169 #sns.swarmplot(x=)
171 import sys
172 sys.exit()
174 #html = f"<head>{shap.getjs()}</head><body>"
175 # Bee swarm
176 # .. note: unexpected algorithm matplotlib!
177 # .. note: does not return an object!
178 plot_bee = shap.plots.beeswarm(shap_values, show=False)
180 # Sow
181 print("\nBEE")
182 print(plot_bee)
184 #print(f)
185 # Waterfall
186 # .. note: not working!
187 #shap.plots.waterfall(shap_values[0], max_display=14)
189 # Force plot
190 # .. note: not working!
191 plot_force = shap.plots.force(explainer.expected_value,
192     explainer.shap_values(X_train), X_train,
193     matplotlib=False, show=False)
195 # Show
196 print("\nFORCE:")
197 print(plot_force)
198 print(plot_force.html())
199 print(shap.save_html('e.html', plot_force))
200 """

Total running time of the script: (0 minutes 2.068 seconds)

Gallery generated by Sphinx-Gallery