03. Basic example

plot main03

Out:

<IPython.core.display.HTML object>

Kernel type: <class 'shap.explainers._tree.Tree'>
[array([-0.08, -0.12, -0.18]), array([0.08, 0.12, 0.18])]

'\nprint(explainer.shap_values(X_train))\n\n# Summary plot\nplot_summary = shap.summary_plot(     explainer.shap_values(X_train),\n    X_train, cmap=\'viridis\',\n    show=False)\n\nplt.tight_layout()\nplt.show()\n\nprint(plot_summary)\n\n\nimport seaborn as sns\nsv = explainer.shap_values(X_train)\nsv = pd.DataFrame(sv, columns=X.columns)\nsv = sv.stack().reset_index()\nsv[\'val\'] = X_train.stack().reset_index()[0]\n\n#import plotly.express as px\n\n#f = px.strip(data_frame=sv, x=0, y=\'level_1\', color=\'val\')\n#f.show()\n\nprint(sv)\n#sns.swarmplot(data=sv, x=0, y=\'level_1\', color=\'viridis\', palette=\'viridis\')\n#sns.stripplot(data=sv, x=0, y=\'level_1\', color=\'viridis\', palette=\'viridis\')\n#plt.show()\nimport sys\nsys.exit()\n#sns.swarmplot(x=)\n\nimport sys\nsys.exit()\n\n#html = f"<head>{shap.getjs()}</head><body>"\n# Bee swarm\n# .. note: unexpected algorithm matplotlib!\n# .. note: does not return an object!\nplot_bee = shap.plots.beeswarm(shap_values, show=False)\n\n# Sow\nprint("\nBEE")\nprint(plot_bee)\n\n#print(f)\n# Waterfall\n# .. note: not working!\n#shap.plots.waterfall(shap_values[0], max_display=14)\n\n# Force plot\n# .. note: not working!\nplot_force = shap.plots.force(explainer.expected_value,\n    explainer.shap_values(X_train), X_train,\n    matplotlib=False, show=False)\n\n# Show\nprint("\nFORCE:")\nprint(plot_force)\nprint(plot_force.html())\nprint(shap.save_html(\'e.html\', plot_force))\n'

  6 # Generic
  7 import numpy as np
  8 import pandas as pd
  9 import matplotlib.pyplot as plt
 10
 11 # Sklearn
 12 from sklearn.model_selection import train_test_split
 13 from sklearn.datasets import load_iris
 14 from sklearn.datasets import load_breast_cancer
 15 from sklearn.naive_bayes import GaussianNB
 16 from sklearn.linear_model import LogisticRegression
 17 from sklearn.tree import DecisionTreeClassifier
 18 from sklearn.ensemble import RandomForestClassifier
 19
 20 # Xgboost
 21 from xgboost import XGBClassifier
 22
 23 # ----------------------------------------
 24 # Load data
 25 # ----------------------------------------
 26 # Seed
 27 seed = 0
 28
 29 # Load dataset
 30 bunch = load_iris()
 31 bunch = load_breast_cancer()
 32 features = list(bunch['feature_names'])
 33
 34 # Create DataFrame
 35 data = pd.DataFrame(data=np.c_[bunch['data'], bunch['target']],
 36                     columns=features + ['target'])
 37
 38 # Create X, y
 39 X = data[bunch['feature_names']]
 40 y = data['target']
 41
 42 # Filter
 43 X = X.iloc[:500, :3]
 44 y = y.iloc[:500]
 45
 46
 47 # Split dataset
 48 X_train, X_test, y_train, y_test = \
 49     train_test_split(X, y, random_state=seed)
 50
 51
 52 # ----------------------------------------
 53 # Classifiers
 54 # ----------------------------------------
 55 # Train classifier
 56 gnb = GaussianNB()
 57 llr = LogisticRegression()
 58 dtc = DecisionTreeClassifier(random_state=seed)
 59 rfc = RandomForestClassifier(random_state=seed)
 60 xgb = XGBClassifier(
 61     min_child_weight=0.005,
 62     eta= 0.05, gamma= 0.2,
 63     max_depth= 4,
 64     n_estimators= 100)
 65
 66 # Select one
 67 clf = rfc
 68
 69 # Fit
 70 clf.fit(X_train, y_train)
 71
 72 # ----------------------------------------
 73 # Find shap values
 74 # ----------------------------------------
 75 # Import
 76 import shap
 77
 78 # Initialise
 79 shap.initjs()
 80
 81 # Get generic explainer
 82 explainer = shap.Explainer(clf, X_train)
 83
 84 # Show kernel type
 85 print("\nKernel type: %s" % type(explainer))
 86
 87 # Variables
 88 #rows = X_train.iloc[5, :]
 89 #shap = explainer.shap_values(row)
 90
 91 # Get shap values
 92 shap_values = explainer.shap_values(X_train.iloc[5, :])
 93 #shap_values = explainer.shap_values(X_test.iloc[5, :])
 94
 95 print(shap_values)
 96
 97 # Force plot
 98 # .. note: not working!
 99 plot_force = shap.plots.force(explainer.expected_value[1],
100     shap_values[1], X_train.iloc[5, :],
101     matplotlib=True, show=True)
102
103 plt.tight_layout()
104 plt.show()
105
106 # ----------------------------------------
107 # Visualize (pending)
108 # ----------------------------------------
109 # .. note: do not park unfinished code in bare triple-quoted
110 #          strings. A string literal is an expression, and when
111 #          it is the last statement of the script Sphinx-Gallery
112 #          prints its repr in the rendered output.
113 #
114 # TODO: plots still to be added as working code:
115 #   - shap.dependence_plot(0, ...) for a single feature
116 #   - shap.summary_plot(...) and shap.plots.beeswarm(...)
117 #   - seaborn strip/swarm plot of the stacked shap values
118 #   - shap.plots.waterfall(...) for a single observation
119 #   - interactive shap.plots.force(..., matplotlib=False)
120 #     exported with shap.save_html('e.html', ...)

Total running time of the script: ( 0 minutes 0.422 seconds)

Gallery generated by Sphinx-Gallery