Note
Click here to download the full example code
04. Basic example
Out:
<IPython.core.display.HTML object>
0%| | 0/1 [00:00<?, ?it/s]
100%|##########| 1/1 [00:00<00:00, 118.41it/s]
6 # coding: utf-8
7
8 # In[1]:
9
10 ### using XGBoost model with SHAP
11
12 import numpy as np
13 import pandas as pd
14 import xgboost as xgb
15 import matplotlib.pyplot as plt
16
17 import shap
18
19 from sklearn.model_selection import train_test_split
20 from sklearn.datasets import make_regression
21
22 shap.initjs()  # load SHAP's JavaScript assets so interactive plots (e.g. force plots) can render in a notebook
23
24
25 # In[2]:
26
### build a small synthetic regression problem
# Column labels x1..x5 for the five generated features.
feature_names = ["x" + str(idx) for idx in range(1, 6)]

# 100 samples, 5 features (3 of them informative), noisy and offset target.
X, y = make_regression(
    n_samples=100,
    n_features=5,
    n_informative=3,
    noise=4.0,
    bias=10.0,
    random_state=0,
)

# Predictors and target live side by side in one labelled frame.
data = pd.DataFrame(X, columns=feature_names).assign(target=y)
34
35
36 # In[3]:
37
# Hold out 30% of the rows for testing; the seed keeps the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    data[feature_names],  ## predictors only
    data["target"],
    test_size=0.30,
    random_state=0,
)
43
44
45 # In[4]:
46
### build the XGBoost regressor (default hyper-parameters) and train it
# ``fit`` returns the fitted estimator itself, so construction and training
# can be chained into a single assignment.
estimator = xgb.XGBRegressor().fit(X_train, y_train)
50
51
52 # In[5]:
53
def xgb_predict(data_asarray):
    """Predict with the fitted ``estimator`` from an unlabelled array.

    Kernel SHAP hands its samples over as a numpy array, which carries no
    column names, so the array is re-wrapped in a DataFrame labelled with
    the module-level ``feature_names`` before ``estimator.predict`` is called.
    """
    return estimator.predict(pd.DataFrame(data_asarray, columns=feature_names))
58
59
60 # In[6]:
61
#### Kernel SHAP
# Summarise the training data with 10 k-means clusters and use that summary
# as the background dataset for the kernel explainer.
X_summary = shap.kmeans(X_train, 10)
shap_kernel_explainer = shap.KernelExplainer(model=xgb_predict, data=X_summary)
65
66
67 # In[7]:
68
## Shapley values with kernel SHAP for a single test row (position 5)
shap_values_single = shap_kernel_explainer.shap_values(X_test.iloc[[5]])
shap.force_plot(
    base_value=shap_kernel_explainer.expected_value,
    shap_values=shap_values_single,
    features=X_test.iloc[[5]],
)
72
73
74
75 # In[9]:
76
#### Tree SHAP
# Explainer built directly from the fitted tree model.
shap_tree_explainer = shap.TreeExplainer(model=estimator)
79
80
81 # In[10]:
82
83 # NOTE: the two calls below are commented out because they raised a deprecation error
84 ## Shapley values with Tree SHAP
85 #shap_values_single = shap_tree_explainer.shap_values(X_test.iloc[[5]])
86 #shap.force_plot(shap_tree_explainer.expected_value, shap_values_single, X_test.iloc[[5]])
87
88 plt.show()  # display any open matplotlib figures
Total running time of the script: ( 0 minutes 0.245 seconds)