Note
Click here to download the full example code
06.c Collateral Sensitivity Index (CRI)
Since computing the Collateral Sensitivity Index is computationally expensive, the results are saved into a .csv file so that they can be easily loaded and displayed. This script shows a very basic graph of this information.
The script generates heatmap visualizations of a dataset related to collateral sensitivity. It uses the Seaborn library to plot rectangular data as color-encoded matrices. The code loads the data from CSV files, creates mappings from antimicrobials to categories and from categories to colors, and then plots the heatmaps using the loaded data and color maps. It also configures titles, colorbars and formatting options, and colors the x-axis tick labels by antimicrobial category to enhance the visualization.
19 # Libraries
20 import numpy as np
21 import pandas as pd
22 import seaborn as sns
23 import matplotlib as mpl
24 import matplotlib.pyplot as plt
25
26 #
27 from pathlib import Path
28 from itertools import combinations
29 from matplotlib.colors import LogNorm, Normalize
30
# Use an 8 pt font for axes titles, axes labels and tick labels so that the
# four side-by-side heatmaps stay readable.
# See https://matplotlib.org/devdocs/users/explain/customizing.html
mpl.rcParams.update({
    'axes.titlesize': 8,
    'axes.labelsize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
})
36
# TERMINAL is True when this file runs as a script (``__file__`` is defined)
# and False in contexts where ``__file__`` does not exist (e.g. an
# interactive session). Only the NameError raised by the missing name is
# caught; a bare ``except:`` would also swallow KeyboardInterrupt etc.
try:
    __file__
    TERMINAL = True
except NameError:
    TERMINAL = False
42
43
44 # --------------------------------
45 # Methods
46 # --------------------------------
47 def _check_ax_ay_equal(ax, ay):
48 return ax==ay
49
50 def _check_ax_ay_greater(ax, ay):
51 return ax>ay
52
# --------------------------------
# Constants
# --------------------------------
# Candidate colormaps (catalogue: https://r02b.github.io/seaborn_palettes/)
#   diverging : coolwarm, RdBu_r, vlag
#   others    : bone, gray, pink, twilight
# In the plotting code of this example only cmap4 (diverging) and cmap8
# (sequential) are referenced; the others are kept as alternatives.
cmap0 = 'coolwarm'
cmap1 = sns.light_palette("seagreen", as_cmap=True)
cmap2 = sns.color_palette("light:b", as_cmap=True)
cmap3 = sns.color_palette("vlag", as_cmap=True)
cmap4 = sns.diverging_palette(220, 20, as_cmap=True)
# NOTE(review): the author marked cmap5 with "no" (presumably rejected).
cmap5 = sns.color_palette("ch:s=.25,rot=-.25", as_cmap=True)
cmap6 = sns.color_palette("light:#5A9", as_cmap=True)
cmap7 = sns.light_palette("#9dedcc", as_cmap=True)
cmap8 = sns.color_palette("YlGn", as_cmap=True)

# Size (width, height) in inches of each figure
figsize = 17, 4
Let’s load the data
# Load the pre-computed contingency table and the antimicrobial categories
path = Path('../../datasets/collateral-sensitivity/20230525-135511')
data = pd.read_csv(path / 'contingency.csv')
abxs = pd.read_csv(path / 'categories.csv')

# Index rows by (specimen, organism, antimicrobial x, antimicrobial y) and
# turn the four contingency counts into integers (missing counts -> 0).
data = data.set_index(['specimen', 'o', 'ax', 'ay'])
for column in ('RR', 'RS', 'SR', 'SS'):
    data[column] = data[column].fillna(0).astype(int)

# Total number of samples behind each antimicrobial pair
data['samples'] = data['RR'] + data['RS'] + data['SR'] + data['SS']
91
def filter_top_pairs(df, n=5):
    """Filter top n (Specimen, Organism) pairs.

    Pairs are ranked by the sum of the ``samples`` column over all rows
    that share the same first two index levels.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame indexed by a MultiIndex whose first two levels identify the
        (specimen, organism) pair; it must have a numeric ``samples`` column.
    n : int, default 5
        Number of pairs to keep.

    Returns
    -------
    pandas.DataFrame
        The rows of ``df`` belonging to one of the top ``n`` pairs, in
        their original order.
    """
    # Total samples per (specimen, organism) pair, largest first
    top = df.groupby(level=[0, 1]) \
        .samples.sum() \
        .sort_values(ascending=False) \
        .head(n)

    # Match on the exact pairs. Slicing each level independently with
    # ``df.loc[idx[specimens, organisms, :, :]]`` selects the cross product
    # of both levels and pulls in pairs that are not among the top ``n``.
    pairs = df.index.droplevel(list(range(2, df.index.nlevels)))
    return df[pairs.isin(top.index)]
107
# Keep the two (specimen, organism) pairs with the most samples, then drop
# antimicrobial pairs backed by 500 samples or fewer.
data = data.pipe(filter_top_pairs, n=2)
data = data.loc[data['samples'] > 500]
Let's see the data
# Print a short summary when running as a script
if TERMINAL:
    summary_lines = (
        "Number of samples: %s" % data.samples.sum(),
        "Number of pairs: %s" % len(data),
        "Data:",
    )
    print("\n")
    for text in summary_lines:
        print(text)
    print(data)

# Bare expression: has no effect in a plain script; presumably it is here so
# the gallery/notebook renderer displays the first rows (all-NaN columns
# removed).
data.iloc[:7, :].dropna(axis=1, how='all')
Let's use the antimicrobial data to create the color mapping variables
# Create dictionary to map category to color.
# One cubehelix colour is generated per distinct, non-missing category;
# categories keep their order of first appearance. Using the same
# NaN-free list for both the palette size and the keys keeps them in sync
# (``nunique`` ignores NaN but ``unique`` does not, so zipping the two
# could silently drop a real category).
# Alternative palette: sns.color_palette('Spectral', len(categories))
labels = abxs.category
categories = labels.dropna().unique()
palette = sns.cubehelix_palette(len(categories),
    light=.9, dark=.1, reverse=True, start=1, rot=-2)
lookup = dict(zip(categories, palette))

# Create dictionary to map antimicrobial code to category
code2cat = dict(zip(abxs.antimicrobial_code, abxs.category))
Let’s display the information
# Loop over the (specimen, organism) groups. Only the first group is drawn
# because of the ``break`` at the end of the loop body.
for i, df in data.groupby(level=[0, 1]):

    # Drop the (specimen, organism) levels; the index is now (ax, ay)
    df = df.droplevel(level=[0, 1])

    # Check possible issues: pairs of a code with itself, and pairs
    # stored with ax > ay.
    ax = df.index.get_level_values(0)
    ay = df.index.get_level_values(1)
    idx1 = _check_ax_ay_equal(ax, ay)
    idx2 = _check_ax_ay_greater(ax, ay)

    # Show
    print("%25s. ax==ay => %5s | ax>ay => %5s" %
          (i, idx1.sum(), idx2.sum()))

    # Re-index to have a square matrix over every code seen in either
    # position. A distinct name (``codes``) is used so the module-level
    # ``abxs`` categories DataFrame is not overwritten.
    codes = set(ax) | set(ay)
    index = pd.MultiIndex.from_product([codes, codes])

    # Square MIS matrix (NaN where the pair is missing)
    mis = df['MIS'] \
        .reindex(index, fill_value=np.nan) \
        .unstack()

    # Square sample-count matrix (0 where the pair is missing)
    freq = df['samples'] \
        .reindex(index, fill_value=0) \
        .unstack()

    # Combine in one square matrix: MIS values above the diagonal and the
    # transposed sample counts on and below the diagonal.
    m1 = mis.copy(deep=True).to_numpy()
    m2 = freq.to_numpy()
    il1 = np.tril_indices(mis.shape[1])
    m1[il1] = m2.T[il1]
    m = pd.DataFrame(m1,
        index=mis.index, columns=mis.columns)

    # .. note: This is the matrix that is used in previous
    #          samples to display the CRI and the count using
    #          the sns.heatmap function (saving it is disabled).

    # The 20 pairs with most samples, without MIS and all-NaN columns
    top_n = df \
        .sort_values('samples', ascending=False) \
        .head(20).drop(columns='MIS') \
        .dropna(axis=1, how='all')

    # Draw
    fig, axs = plt.subplots(nrows=1, ncols=4,
        sharey=False, sharex=False, figsize=figsize,
        gridspec_kw={'width_ratios': [2, 3, 3, 3.5]})

    sns.heatmap(data=mis * 100, annot=False, linewidth=.5,
        cmap='coolwarm', vmin=-70, vmax=70, center=0,
        annot_kws={"size": 8}, square=True,
        ax=axs[2], xticklabels=True, yticklabels=True)

    sns.heatmap(data=freq, annot=False, linewidth=.5,
        cmap='Blues', norm=LogNorm(),
        annot_kws={"size": 8}, square=True,
        ax=axs[1], xticklabels=True, yticklabels=True)

    sns.heatmap(top_n,
        annot=False, linewidth=0.5,
        cmap='Blues', ax=axs[0], zorder=1,
        vmin=None, vmax=None, center=None, robust=True,
        square=False, xticklabels=True, yticklabels=True,
        cbar_kws={
            'use_gridspec': True,
            'location': 'right'
        }
    )

    # Display counts (lower triangle, log scale) and MIS (upper triangle)
    # on the same axis; both masks include the diagonal, so it is hidden.
    masku = np.triu(np.ones_like(m))
    maskl = np.tril(np.ones_like(m))
    sns.heatmap(data=m, cmap=cmap8, mask=masku, ax=axs[3],
        annot=False, linewidth=0.5, norm=LogNorm(),
        annot_kws={"size": 8}, square=True, vmin=0)
    sns.heatmap(data=m, cmap=cmap4, mask=maskl, ax=axs[3],
        annot=False, linewidth=0.5, vmin=-0.7, vmax=0.7,
        center=0, annot_kws={"size": 8}, square=True,
        xticklabels=True, yticklabels=True)

    # Configure axes
    axs[0].set_title('Contingency')
    axs[1].set_title('Number of samples')
    axs[2].set_title('Collateral Sensitivity Index')
    axs[3].set_title('Samples / Collateral Sensitivity')

    # ------------------------------------------
    # Add category colors on xtick labels
    # ------------------------------------------
    # Colour per code, built once. ``dict.get`` with a 'k' (black) fallback
    # is used for both lookups so codes without a known category (or a
    # category without a colour) do not yield NaN, which matplotlib
    # rejects as a colour.
    colors = {code: lookup.get(code2cat.get(code), 'k')
              for code in m.columns}

    for lbl in axs[3].get_xticklabels():
        try:
            lbl.set_color(colors.get(lbl.get_text(), 'k'))
            lbl.set_weight('bold')
        except ValueError as e:
            # Best effort: report labels that could not be styled
            print(lbl.get_text(), e)

    # Configure plot
    plt.suptitle('%s - %s' % (i[0], i[1]))
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1)

    # Exit loop
    break

# Show
plt.show()
Out:
('URICUL', 'ECOL'). ax==ay => 18 | ax>ay => 0
Total running time of the script: ( 0 minutes 2.476 seconds)