.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_examples\shap\plot_main05.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download__examples_shap_plot_main05.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__examples_shap_plot_main05.py:


05. Display SHAP for Sequential Data
=================================================

This script shows how to visualize SHAP values for sequential or
time-series data, a common need when interpreting LSTMs and other
RNNs. The SHAP output for such models is three-dimensional
(samples, timesteps, features), so it has to be sliced or reshaped
before it can be plotted.

The script demonstrates several approaches:

- **Slicing SHAP data:** use the standard ``shap.summary_plot`` on
  two-dimensional slices, showing either all features at one timestep
  or all timesteps for one feature.
- **Custom seaborn plots:** use ``seaborn.stripplot`` and
  ``seaborn.swarmplot`` for finer control over the appearance and
  layout of the plot.
- **Coloring by feature value:** helper functions replicate the SHAP
  convention of coloring each point by its original feature value, so
  that the custom plots can be read in the same way.

The example is aimed at readers who want SHAP plots tailored to models
that handle sequential data rather than the default plots.

.. GENERATED FROM PYTHON SOURCE LINES 29-168

.. code-block:: default
    :lineno-start: 30


    # Libraries
    import shap
    import numpy as np
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    import matplotlib.colorbar
    import matplotlib.colors
    import matplotlib.cm

    from mpl_toolkits.axes_grid1 import make_axes_locatable

    # __file__ is only defined when running as a script (terminal)
    try:
        __file__
        TERMINAL = True
    except NameError:
        TERMINAL = False


    # ------------------------
    # Methods
    # ------------------------
    def scalar_colormap(values, cmap, vmin, vmax):
        """This method creates a colormap based on values.

        Parameters
        ----------
        values : array-like
            The values to create the corresponding colors

        cmap : str
            The colormap

        vmin, vmax : float
            The minimum and maximum possible values

        Returns
        -------
        colormap, norm
            The seaborn color palette (one color per value) and
            the normalization used to compute it.
        """
        # Create scalar mappable
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
        mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
        # Get color map
        colormap = sns.color_palette([mapper.to_rgba(i) for i in values])
        # Return
        return colormap, norm


    def scalar_palette(values, cmap, vmin, vmax):
        """This method creates a color palette based on values.

        Parameters
        ----------
        values : array-like
            The values to create the corresponding colors

        cmap : str
            The colormap

        vmin, vmax : float
            The minimum and maximum possible values

        Returns
        -------
        colors, norm
            The dictionary {value: color} and the normalization
            used to compute it.
        """
        # Create a matplotlib colormap from name
        #cmap = sns.light_palette(cmap, reverse=False, as_cmap=True)
        cmap = sns.color_palette(cmap, as_cmap=True)
        # Normalize to the range of possible values
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        # Create a color dictionary (value : color from colormap)
        colors = {cval: cmap(norm(cval)) for cval in values}
        # Return
        return colors, norm


    def create_random_shap(samples, timesteps, features):
        """Create random LSTM data.

        .. note:: No need to create the 3D matrix and then reshape to 2D.
                  It would be possible to create directly the 2D matrix.

        Parameters
        ----------
        samples: int
            The number of observations
        timesteps: int
            The number of time steps
        features: int
            The number of features

        Returns
        -------
        Stacked matrix with the data.
        """
        # .. note: Either perform a pre-processing step such as
        #          normalization or generate the features within
        #          the appropriate interval.
        # Create dataset (y is not used below)
        x = np.random.randint(low=0, high=100,
            size=(samples, timesteps, features))
        y = np.random.randint(low=0, high=2, size=samples).astype(float)
        i = np.vstack(np.dstack(np.indices((samples, timesteps))))

        # Create DataFrame
        df = pd.DataFrame(
            data=np.hstack((i, x.reshape((-1, features)))),
            columns=['sample', 'timestep'] + \
                    ['f%s'%j for j in range(features)]
        )

        # Stack into long format (one row per sample, timestep, feature)
        df_stack = df.set_index(['sample', 'timestep']).stack()
        df_stack.name = 'shap_values'
        df_stack = df_stack.to_frame()
        df_stack.index.names = ['sample', 'timestep', 'features']
        df_stack = df_stack.reset_index()

        # Add random feature values
        df_stack['feature_values'] = np.random.randint(
            low=0, high=100, size=df_stack.shape[0])

        return df_stack


    def load_shap_file():
        """Load the SHAP values from disk (first column dropped)."""
        data = pd.read_csv('./data/shap.csv')
        data = data.iloc[: , 1:]
        #data.timestep = data.timestep - (data.timestep.nunique() - 1)
        return data


.. GENERATED FROM PYTHON SOURCE LINES 169-170

Let's generate and/or load the SHAP values.

.. GENERATED FROM PYTHON SOURCE LINES 170-196

.. code-block:: default
    :lineno-start: 171


    # .. note: The right format to use for plotting depends
    #          on the library we use. The data structure is
    #          good when using seaborn
    # Load data
    data = create_random_shap(10, 6, 4)
    #data = load_shap_file()
    #data = data[data['sample'] < 100]

    shap_values = pd.pivot_table(data,
            values='shap_values',
            index=['sample', 'timestep'],
            columns=['features'])

    feature_values = pd.pivot_table(data,
            values='feature_values',
            index=['sample', 'timestep'],
            columns=['features'])

    # Show
    if TERMINAL:
        print("\nShow:")
        print(data)
        print(shap_values)
        print(feature_values)


.. GENERATED FROM PYTHON SOURCE LINES 197-198

Let's see what ``data`` looks like.

.. GENERATED FROM PYTHON SOURCE LINES 198-200

.. code-block:: default
    :lineno-start: 198

    data.head(10)


.. code-block:: none

       sample  timestep features  shap_values  feature_values
    0       0         0       f0           61              65
    1       0         0       f1           79               9
    2       0         0       f2           10              46
    3       0         0       f3           76              68
    4       0         1       f0           49               4
    5       0         1       f1           66              71
    6       0         1       f2           24              86
    7       0         1       f3           38              32
    8       0         2       f0           68               0
    9       0         2       f1           59              67

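The long format above has one row per (sample, timestep, feature)
triple. As a minimal sketch, assuming the SHAP values arrive as a NumPy
array of shape (samples, timesteps, features), the same layout can be
built directly without the intermediate 2D frame. The helper name
``to_long_format`` and the random inputs are hypothetical and only
serve as an illustration.

.. code-block:: python

    import numpy as np
    import pandas as pd


    def to_long_format(shap_3d, feat_3d, feature_names=None):
        """Flatten (samples, timesteps, features) arrays into long format."""
        samples, timesteps, features = shap_3d.shape
        if feature_names is None:
            feature_names = ['f%s' % j for j in range(features)]
        # One row per (sample, timestep, feature). The C-order ravel
        # below matches the order produced by MultiIndex.from_product
        # (last level varies fastest).
        index = pd.MultiIndex.from_product(
            [range(samples), range(timesteps), feature_names],
            names=['sample', 'timestep', 'features'])
        frame = pd.DataFrame({
            'shap_values': shap_3d.ravel(),
            'feature_values': feat_3d.ravel(),
        }, index=index)
        return frame.reset_index()


    # Hypothetical random inputs, only for illustration
    rng = np.random.default_rng(0)
    shap_3d = rng.normal(size=(10, 6, 4))
    feat_3d = rng.integers(0, 100, size=(10, 6, 4))

    long_df = to_long_format(shap_3d, feat_3d)

The resulting frame has the same five columns as ``data`` and can be
passed to ``pd.pivot_table`` in the same way.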

.. GENERATED FROM PYTHON SOURCE LINES 201-202

Let's see what ``shap_values`` looks like.

.. GENERATED FROM PYTHON SOURCE LINES 202-204

.. code-block:: default
    :lineno-start: 202

    shap_values.iloc[:10, :5]


.. code-block:: none

    features           f0    f1    f2    f3
    sample timestep
    0      0         61.0  79.0  10.0  76.0
           1         49.0  66.0  24.0  38.0
           2         68.0  59.0  34.0  89.0
           3         10.0   7.0  13.0  68.0
           4         11.0  35.0   3.0  30.0
           5         32.0  62.0  90.0  89.0
    1      0         53.0  70.0  92.0  51.0
           1         71.0  31.0  45.0  27.0
           2         74.0  68.0  72.0  27.0
           3         15.0  99.0  65.0  82.0

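Each row of the pivoted table is one (sample, timestep) pair, so the
table maps back onto the three-dimensional layout with a plain reshape.
The following is a minimal sketch on a small synthetic table, assuming
that every sample has the same number of timesteps and that the rows
are sorted by (sample, timestep). The names ``cube``, ``wide`` and
``restored`` are hypothetical.

.. code-block:: python

    import numpy as np
    import pandas as pd

    samples, timesteps, features = 3, 4, 2

    # Synthetic 3D values (samples, timesteps, features)
    cube = np.arange(samples * timesteps * features, dtype=float) \
        .reshape(samples, timesteps, features)

    # Pivoted layout: rows are (sample, timestep), columns are features
    index = pd.MultiIndex.from_product(
        [range(samples), range(timesteps)],
        names=['sample', 'timestep'])
    wide = pd.DataFrame(cube.reshape(-1, features),
        index=index, columns=['f0', 'f1'])

    # Back to 3D: sorting guarantees the (sample, timestep) row order
    restored = wide.sort_index().to_numpy() \
        .reshape(samples, timesteps, features)

    # One feature across all timesteps: shape (samples, timesteps)
    f0_by_timestep = wide['f0'].to_numpy().reshape(-1, timesteps)

The last line is the same reshape that is needed to plot all the
timesteps of a single feature.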

.. GENERATED FROM PYTHON SOURCE LINES 205-206

Let's see what ``feature_values`` looks like.

.. GENERATED FROM PYTHON SOURCE LINES 206-209

.. code-block:: default
    :lineno-start: 206

    feature_values.iloc[:10, :5]


.. code-block:: none

    features           f0    f1    f2    f3
    sample timestep
    0      0         65.0   9.0  46.0  68.0
           1          4.0  71.0  86.0  32.0
           2          0.0  67.0  50.0  41.0
           3         60.0  85.0  49.0  82.0
           4         61.0   1.0  59.0  12.0
           5         39.0  94.0  19.0   8.0
    1      0         36.0  54.0  90.0   6.0
           1         36.0  40.0  52.0  74.0
           2          6.0  25.0  79.0  17.0
           3         51.0   6.0  13.0  70.0

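Both pivoted tables share the (sample, timestep) index, so all the rows
of one timestep can be selected in several equivalent ways. The sketch
below compares three of them on a small synthetic table. The names
``table``, ``by_mask``, ``by_position`` and ``by_xs`` are hypothetical,
and the positional variant assumes that the rows are sorted and that
every sample has the same number of timesteps.

.. code-block:: python

    import numpy as np
    import pandas as pd

    SAMPLES, TIMESTEPS = 3, 4

    index = pd.MultiIndex.from_product(
        [range(SAMPLES), range(TIMESTEPS)],
        names=['sample', 'timestep'])
    table = pd.DataFrame(
        np.arange(SAMPLES * TIMESTEPS * 2).reshape(-1, 2),
        index=index, columns=['f0', 'f1'])

    step = 2

    # Option 1: boolean mask built from the index level
    mask = table.index.get_level_values('timestep') == step
    by_mask = table.loc[mask]

    # Option 2: positional arithmetic (what a plain NumPy array needs)
    positions = np.arange(SAMPLES) * TIMESTEPS + step
    by_position = table.iloc[positions]

    # Option 3: cross-section, which drops the timestep level
    by_xs = table.xs(step, level='timestep')

The first two options keep the full index, whereas ``xs`` returns a
frame indexed by sample only.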

.. GENERATED FROM PYTHON SOURCE LINES 210-214

Display using ``shap.summary_plot``
-----------------------------------------------

The first option is to use the ``shap`` library to plot the results.

.. GENERATED FROM PYTHON SOURCE LINES 214-223

.. code-block:: default
    :lineno-start: 215


    # Let's define/extract some useful variables.
    N = 4  # max loops filter
    TIMESTEPS = len(shap_values.index.unique(level='timestep'))  # number of timesteps
    SAMPLES = len(shap_values.index.unique(level='sample'))      # number of samples

    shap_min = data.shap_values.min()
    shap_max = data.shap_values.max()


.. GENERATED FROM PYTHON SOURCE LINES 224-225

Now, let's display the SHAP values for all features in each timestep.

.. GENERATED FROM PYTHON SOURCE LINES 225-248

.. code-block:: default
    :lineno-start: 226


    # For each timestep (visualise all features)
    for i, step in enumerate(range(TIMESTEPS)[:N]):
        # Show
        #print('%2d. %s' % (i, step))

        # .. note: First option (commented) is only necessary if we work
        #          with a numpy array. However, since we are using a DataFrame
        #          with the timestep, we can index by that index level.
        # Compute indices
        #indice = np.arange(SAMPLES)*TIMESTEPS + step
        indice = shap_values.index.get_level_values('timestep') == step

        # Create auxiliary matrices
        shap_aux = shap_values.iloc[indice]
        feat_aux = feature_values.iloc[indice]

        # Display
        plt.figure()
        plt.title("Timestep: %s" % step)
        shap.summary_plot(shap_aux.to_numpy(), feat_aux, show=False)
        plt.xlim(shap_min, shap_max)


.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_001.png
         :alt: Timestep: 0
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_002.png
         :alt: Timestep: 1
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_002.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_003.png
         :alt: Timestep: 2
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_003.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_004.png
         :alt: Timestep: 3
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_004.png
         :class: sphx-glr-multi-img


.. GENERATED FROM PYTHON SOURCE LINES 249-250

Now, let's display the SHAP values for all timesteps of each feature.

.. GENERATED FROM PYTHON SOURCE LINES 250-271

.. code-block:: default
    :lineno-start: 251


    # For each feature (visualise all time-steps)
    for i, f in enumerate(shap_values.columns[:N]):
        # Show
        #print('%2d. %s' % (i, f))

        # Create auxiliary matrices (select feature and reshape)
        shap_aux = shap_values.iloc[:, i] \
            .to_numpy().reshape(-1, TIMESTEPS)
        feat_aux = feature_values.iloc[:, i] \
            .to_numpy().reshape(-1, TIMESTEPS)
        feat_aux = pd.DataFrame(feat_aux,
            columns=['timestep %s'%j for j in range(TIMESTEPS)]
        )

        # Show
        plt.figure()
        plt.title("Feature: %s" % f)
        shap.summary_plot(shap_aux, feat_aux, sort=False, show=False)
        plt.xlim(shap_min, shap_max)


.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_005.png
         :alt: Feature: f0
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_005.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_006.png
         :alt: Feature: f1
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_006.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_007.png
         :alt: Feature: f2
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_007.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_008.png
         :alt: Feature: f3
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_008.png
         :class: sphx-glr-multi-img


.. GENERATED FROM PYTHON SOURCE LINES 272-274

.. note:: When the y-axis represents timesteps, ``sort=False`` is passed
          to ``summary_plot`` so that the rows keep their chronological
          order instead of being sorted by importance.

.. GENERATED FROM PYTHON SOURCE LINES 276-284

Display using ``sns.stripplot``
-------------------------------

.. warning:: This method seems to be quite slow.

Let's display the SHAP values of each feature for all timesteps. In
contrast to the previous example, the timesteps are now on the x-axis
and the SHAP values on the y-axis.

.. GENERATED FROM PYTHON SOURCE LINES 284-331

.. code-block:: default
    :lineno-start: 286


    def add_colorbar(fig, cmap, norm):
        """Add a vertical colorbar to the right of the current axes."""
        divider = make_axes_locatable(plt.gca())
        ax_cb = divider.new_horizontal(size="5%", pad=0.05)
        fig.add_axes(ax_cb)
        cb1 = matplotlib.colorbar.ColorbarBase(ax_cb,
             cmap=cmap, norm=norm, orientation='vertical')


    # Loop
    for i, (name, df) in enumerate(data.groupby('features')):

        # Get palette (feature value -> color) and normalization
        values = df.feature_values
        palette, norm = scalar_palette(values=values, cmap='coolwarm',
            vmin=values.min(), vmax=values.max())

        # Display
        fig, ax = plt.subplots()
        ax = sns.stripplot(x='timestep', y='shap_values', hue='feature_values',
            palette=palette, data=df, ax=ax)

        # Colormap object for the colorbar. The original call to
        # matplotlib.cm.get_cmap is deprecated since Matplotlib 3.7.
        cmap = mpl.colormaps['coolwarm']

        # Configure axes
        plt.title(name)
        plt.legend([], [], frameon=False)
        ax.invert_xaxis()
        add_colorbar(plt.gcf(), cmap, norm)

        # End
        if int(i) > N:
            break

    # Show
    plt.show()


.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_009.png
         :alt: f0
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_009.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_010.png
         :alt: f1
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_010.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_011.png
         :alt: f2
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_011.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_012.png
         :alt: f3
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_012.png
         :class: sphx-glr-multi-img


.. GENERATED FROM PYTHON SOURCE LINES 332-341

Display using ``sns.swarmplot``
-------------------------------

.. note:: If the number of samples is too high, ``swarmplot`` cannot
          place all the points without overlap and leaves some of them
          out. In that case it is better to use ``stripplot``.

Let's display the SHAP values for each timestep.

.. GENERATED FROM PYTHON SOURCE LINES 341-418

.. code-block:: default
    :lineno-start: 342


    # Loop
    for i, (name, df) in enumerate(data.groupby('features')):

        # Get palette (feature value -> color) and normalization
        values = df.feature_values
        palette, norm = scalar_palette(values=values, cmap='coolwarm',
            vmin=values.min(), vmax=values.max())

        # Display
        fig, ax = plt.subplots()
        ax = sns.swarmplot(x='timestep', y='shap_values', hue='feature_values',
            palette=palette, data=df, size=2, ax=ax)

        # Colormap object for the colorbar. The original call to
        # matplotlib.cm.get_cmap is deprecated since Matplotlib 3.7.
        cmap = mpl.colormaps['coolwarm']

        # Configure axes
        plt.title(name)
        plt.legend([], [], frameon=False)
        ax.invert_xaxis()
        add_colorbar(plt.gcf(), cmap, norm)

        # End
        if int(i) > N:
            break

    # Show
    plt.show()

    # Unused draft, kept inside a string so that it does not run.
    """
    sns.set_theme(style="ticks")

    # Create a dataset with many short random walks
    rs = np.random.RandomState(4)
    pos = rs.randint(-1, 2, (20, 5)).cumsum(axis=1)
    pos -= pos[:, 0, np.newaxis]
    step = np.tile(range(5), 20)
    walk = np.repeat(range(20), 5)
    df = pd.DataFrame(np.c_[pos.flat, step, walk],
                      columns=["position", "step", "walk"])
    # Initialize a grid of plots with an Axes for each walk
    #grid = sns.FacetGrid(df_stack, col="walk", hue="f", palette="tab20c",
    #                     col_wrap=4, height=1.5)

    grid = sns.FacetGrid(df_stack, hue="f",
        palette="tab20c", height=1.5)

    # Draw a horizontal line to show the starting point
    grid.refline(y=0, linestyle=":")

    # Draw a line plot to show the trajectory of each random walk
    grid.map(plt.plot, "t", "value", marker="o")

    # Adjust the tick positions and labels
    grid.set(xticks=np.arange(5), yticks=[-3, 3],
             xlim=(-.5, 4.5), ylim=(-3.5, 3.5))

    # Adjust the arrangement of the plots
    grid.fig.tight_layout(w_pad=1)

    """

    #plt.show()


.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_013.png
         :alt: f0
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_013.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_014.png
         :alt: f1
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_014.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_015.png
         :alt: f2
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_015.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /_examples/shap/images/sphx_glr_plot_main05_016.png
         :alt: f3
         :srcset: /_examples/shap/images/sphx_glr_plot_main05_016.png
         :class: sphx-glr-multi-img


.. GENERATED FROM PYTHON SOURCE LINES 419-422

Display using ``sns.FacetGrid``
-------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 422-428

.. code-block:: default
    :lineno-start: 423


    #g = sns.FacetGrid(df_stack, col="f", hue='original')
    #g.map(sns.swarmplot, "t", "value", alpha=.7)
    #g.add_legend()


.. GENERATED FROM PYTHON SOURCE LINES 429-432

Display using ``shap.beeswarm``
-------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 432-437

.. code-block:: default
    :lineno-start: 433


    # REF: https://github.com/slundberg/shap/blob/master/shap/plots/_beeswarm.py
    #
    # .. note: It needs a kernel explainer, and while it works with
    #          common kernels (plot_main07.py) it does not work with
    #          the DeepKernel for some reason (mask related).


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 3.032 seconds)


.. _sphx_glr_download__examples_shap_plot_main05.py:

.. only:: html

  .. container:: sphx-glr-footer
     :class: sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_main05.py <plot_main05.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_main05.ipynb <plot_main05.ipynb>`


.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_