Note
Click here to download the full example code
08. Patient therapy flow (sankey)
This Sankey diagram visualizes treatment pathways by showing how therapies change between consecutive days. The width of each link is proportional to the number of transitions it represents, and how that number should be read depends on the structure of the input data. If the data contain a single entry per therapy, subject and day, the width counts patients. If the data contain multiple entries for the same therapy on a given day (for example, one entry per dose), the width counts applications. Note that the pipeline below first collapses all entries for a subject and day into a single (possibly combined) therapy, so in this example each unit of width corresponds to one patient.
The level of detail can be changed directly in the code (via the ANALYSIS_LEVEL setting), so that trajectories are displayed by individual drug, by broader drug class, or by WHO AWaRe category (Access, Watch, Reserve).
19 # Libraries
20 import pandas as pd
21 import numpy as np
22 import plotly.graph_objects as go
23 import matplotlib.cm as cm
24
# Detect whether the script runs from a file (``__file__`` is defined) or in an
# environment where ``__file__`` is not defined (e.g. an interactive session).
# plot_sankey_robust uses the flag to pick the day-label offset and margins.
try:
    __file__
    TERMINAL = True
except NameError:
    # Only the "name is undefined" case is expected here; a bare ``except``
    # would also swallow KeyboardInterrupt / SystemExit.
    TERMINAL = False
30
31 # =============================================================================
32 # LOOKUP TABLES & CONSTANTS
33 # =============================================================================
34
def drug_to_aware_default():
    """Return the default drug -> WHO AWaRe category lookup.

    Keys are lowercase drug names. Values are 'Access', 'Watch' or 'Reserve'.
    The extra key 'no therapy' maps to 'No Therapy'.
    """
    access_drugs = (
        "amikacin", "amoxicillin", "co-amoxiclav", "benzylpenicillin sodium",
        "cefalexin", "cefazolin", "chloramphenicol", "clindamycin",
        "doxycycline", "flucloxacillin", "gentamicin", "mecillinam",
        "metronidazole", "nitrofurantoin", "pivmecillinam", "co-trimoxazole",
        "trimethoprim",
    )
    watch_drugs = (
        "azithromycin", "cefepime", "ceftazidime", "ceftriaxone", "cefuroxime",
        "ciprofloxacin", "clarithromycin", "ertapenem", "erythromycin",
        "fosfomycin", "levofloxacin", "meropenem", "moxifloxacin",
        "piperacillin-tazobactam", "teicoplanin", "temocillin", "vancomycin",
    )
    reserve_drugs = (
        "aztreonam", "cefiderocol", "ceftazidime/avibactam",
        "ceftolozane/tazobactam", "colistin", "daptomycin", "linezolid",
    )

    # Insertion order matters: callers sample from list(lookup.keys()).
    lookup = dict.fromkeys(access_drugs, "Access")
    lookup.update(dict.fromkeys(watch_drugs, "Watch"))
    lookup.update(dict.fromkeys(reserve_drugs, "Reserve"))
    # Tazocin is a brand name of piperacillin-tazobactam.
    lookup["tazocin"] = "Watch"
    lookup["no therapy"] = "No Therapy"
    return lookup
55
def drug_to_aware_class_winnie():
    """Return Winnie's extended drug -> WHO AWaRe category lookup.

    This table covers many formulation-specific names (e.g. '(anes)', '_IV',
    '_oral' variants). Every name is lowercased so that lookups can be
    case-insensitive. The extra key 'no therapy' maps to 'No Therapy'.
    """
    tiers = (
        ("Access", (
            "Amikacin", "Amoxicillin", "Amoxicillin (contains penicillin)",
            "Amoxicillin/clavulanic-acid", "Amoxiclavulanic-acid",
            "Benzylpenicillin sodium", "Co-amoxiclav",
            "Co-amoxiclav (contains penicillin)",
            "Co-amoxiclav (contains penicillin) (anes)",
            "Co-amoxiclav 400mg/57mg in 5ml oral suspension (contains penicillin)",
            "co-amoxiclav 250mg/62mg in 5ml oral suspension (contains penicillin)",
            "Cefalexin", "Cefazolin", "Chloramphenicol", "Clindamycin",
            "Doxycycline", "Flucloxacillin (contains penicillin)", "Gentamicin",
            "Gentamicin (anes)", "Mecillinam", "Metronidazole",
            "Metronidazole (anes)", "Metronidazole_IV", "Metronidazole_oral",
            "Nitrofurantoin", "Co-amoxiclav in suspension", "Pivmecillinam",
            "Pivmecillinam (contains penicillin)",
            "Sulfamethoxazole/trimethoprim", "Trimethoprim", "co-trimoxazole",
        )),
        ("Watch", (
            "Azithromycin", "Cefepime", "Ceftazidime", "Ceftriaxone",
            "Cefuroxime", "Cefuroxime (anes)", "Ciprofloxacin", "Clarithromycin",
            "Ertapenem", "Erythromycin", "Fosfomycin", "Fosfomycin_oral",
            "Levofloxacin", "Meropenem", "Moxifloxacin",
            "Piperacillin/tazobactam",
            "piperacillin-tazobactam (contains penicillin)",
            "piperacillin + tazobactam (contains penicillin)", "Teicoplanin",
            "Teicoplanin (anes)", "Temocillin",
            "Temocillin (contains penicillin)", "Vancomycin",
            "Vancomycin (anes)", "Vancomycin (anes) 1g", "Vancomycin_IV",
            "Vancomycin_oral",
        )),
        ("Reserve", (
            "Aztreonam", "Cefiderocol", "Ceftazidime/avibactam",
            "Ceftolozane/tazobactam", "Colistin", "Colistin_IV", "Dalbavancin",
            "Daptomycin", "Fosfomycin_IV", "Iclaprim", "Linezolid",
            "Tigecycline",
        )),
    )

    # A dict gives O(1) lookups instead of scanning lists every time.
    # Later tiers win on duplicate names, matching successive dict.update calls.
    lookup = {name.lower(): tier for tier, names in tiers for name in names}
    lookup["no therapy"] = "No Therapy"
    return lookup
95
# AWaRe lookup used throughout this example (lowercase drug name -> category).
# drug_to_aware_class_winnie() is an alternative table with more formulations.
DRUG_TO_AWARE_CLASS = drug_to_aware_default()
97
# .. note:: These lookups are not exhaustive. Any antibiotic missing from them
#           is labelled 'Unknown' during processing. Extend the tables so
#           that they cover every drug that appears in your data.
# .. note:: Keys must be lowercase, because drug names are lowercased
#           before the lookup.

# Drug name -> pharmacological class.
DRUG_TO_CLASS = {
    "amoxicillin": "Penicillin",
    "flucloxacillin": "Penicillin",
    "pivmecillinam": "Penicillin",
    "co-amoxiclav": "Penicillin and beta-lactamase inhibitor",
    "tazocin": "Penicillin and beta-lactamase inhibitor",
    "piperacillin-tazobactam": "Penicillin and beta-lactamase inhibitor",
    "cefuroxime": "2nd Gen Cephalosporin",
    "ceftriaxone": "3rd Gen Cephalosporin",
    "ceftazidime": "3rd Gen Cephalosporin",
    "cefepime": "4th Gen Cephalosporin",
    "meropenem": "Carbapenem",
    "aztreonam": "Monobactam",
    "gentamicin": "Aminoglycoside",
    "amikacin": "Aminoglycoside",
    "ciprofloxacin": "Fluoroquinolone",
    "clarithromycin": "Macrolide",
    "erythromycin": "Macrolide",
    "vancomycin": "Glycopeptide",
    "teicoplanin": "Glycopeptide",
    "linezolid": "Oxazolidinone",
    "metronidazole": "Nitroimidazole",
    "doxycycline": "Tetracycline",
    "fosfomycin": "Phosphonic acid",
    "co-trimoxazole": "Folate pathway inhibitor",
    "chloramphenicol": "Amphenicol",
    "temocillin": "Penicillin",
    # Keeps 'No Therapy' rows as such (otherwise they would become 'Unknown').
    "no therapy": "No Therapy",
}

# Analysis level -> name of the column holding the therapy label at that level.
COLUMN_FOR_ANALYSIS = {
    'drug': 'drug_norm',
    'class': 'drug_class',
    'aware': 'aware_class',
}
122
123
124 # =============================================================================
125 # HELPER FUNCTIONS
126 # =============================================================================
127
def create_validation_data():
    """Return a tiny two-patient dataset for checking the Sankey logic by eye.

    Scenarios covered:
    - single therapy -> combination therapy (both patients, day 1 -> 2);
    - flows from different sources merging into one combination node (day 2);
    - combination therapy -> single therapy (day 2 -> 3);
    - transitions into 'No Therapy', a gap with no record for patient 101 on
      day 6, and a journey that ends early (patient 102 stops on day 3).
    """
    columns = ['SUBJECT', 'DAY_NUM', 'drug_norm']
    rows = [
        # Patient 101: longer journey through a combination.
        (101, 1, 'amoxicillin'),
        (101, 2, 'amoxicillin'),    # day 2: amoxicillin + linezolid
        (101, 2, 'linezolid'),
        (101, 3, 'vancomycin'),     # leaves the combination
        (101, 4, 'No Therapy'),
        (101, 5, 'No Therapy'),
        (101, 7, 'vancomycin'),     # no record for day 6
        (101, 8, 'vancomycin'),
        # Patient 102: shorter journey that merges into the same combination.
        (102, 1, 'ciprofloxacin'),
        (102, 2, 'amoxicillin'),    # same day-2 combination as patient 101
        (102, 2, 'linezolid'),
        (102, 3, 'vancomycin'),     # journey ends here
    ]
    return pd.DataFrame(rows, columns=columns)
158
159
def remove_days_randomly(df, missing_day_fraction=0.01, rng=None):
    """Randomly drop a fraction of the rows that are not on day 1.

    Rows with ``DAY_NUM == 1`` are never dropped, so every subject keeps its
    first day. The unit removed is a row (one therapy entry), not a whole
    day: on a combination-therapy day only some of the drugs may disappear.

    Args:
        df: Patient data with a ``DAY_NUM`` column. Rows are selected by
            position, so duplicate index labels are handled correctly.
        missing_day_fraction: Fraction (0-1) of the droppable rows to remove;
            the row count is rounded down. Values <= 0 disable removal.
        rng: Optional ``numpy.random.Generator`` / ``RandomState`` used for
            sampling. Defaults to the global ``numpy.random`` state.

    Returns:
        The DataFrame without the sampled rows (``df`` itself when nothing
        is removed).

    Raises:
        ValueError: If ``missing_day_fraction`` is greater than 1.
    """
    if missing_day_fraction <= 0:
        return df
    if missing_day_fraction > 1:
        raise ValueError(
            f"missing_day_fraction must be <= 1, got {missing_day_fraction}")

    sampler = np.random if rng is None else rng
    # Positions (not index labels) of rows that may be dropped: dropping by
    # label would remove every row sharing a duplicated label.
    candidates = np.flatnonzero((df['DAY_NUM'] != 1).to_numpy())
    num_to_drop = int(len(candidates) * missing_day_fraction)
    if num_to_drop == 0:
        return df
    dropped = sampler.choice(candidates, size=num_to_drop, replace=False)
    keep = np.ones(len(df), dtype=bool)
    keep[dropped] = False
    return df[keep]
172
173
def create_synthetic_data(num_patients=50, max_los=7):
    """
    Create synthetic patient antibiotic data with variable lengths of stay.

    Each patient starts on a random therapy (a drug from DRUG_TO_AWARE_CLASS
    or 'No Therapy'). The therapy then evolves day by day:

    - on a drug: stay on it (p=0.6), stop ('No Therapy', p=0.1), or switch
      to another drug (p=0.3, spread evenly);
    - off therapy: stay off (p=0.8) or start a drug (p=0.2, spread evenly).

    .. note:: For the first 5 patients (or fewer if ``num_patients`` < 5) it
              adds amoxicillin and linezolid rows on day 2, on top of that
              day's sampled therapy, to create combination therapies.
    .. note:: It randomly removes 10% of the non-first-day rows, so some days
              have no therapy information.

    Args:
        num_patients: The number of patients to simulate.
        max_los: The maximum possible length of stay for any patient. Every
            patient stays at least 2 days, so this must be >= 2.

    Returns:
        DataFrame with columns SUBJECT, DAY_NUM and drug_norm.
    """
    # The lookup stores a lowercase 'no therapy' key. Replace it with the
    # canonical 'No Therapy' label that the code below and the rest of the
    # pipeline (gap filling, top-N filter, colours) compare against.
    drugs = [drug for drug in DRUG_TO_AWARE_CLASS if drug != 'no therapy']
    therapies = drugs + ['No Therapy']

    records = []
    for pid in range(num_patients):
        los = np.random.randint(2, max_los + 1)
        current_therapy = np.random.choice(therapies)
        records.append({'SUBJECT': pid, 'DAY_NUM': 1, 'drug_norm': current_therapy})
        for day in range(2, los + 1):
            if current_therapy == 'No Therapy':
                # The 0.8 weight belongs to staying off therapy (previously it
                # was attached to whichever drug happened to be listed first).
                choices = ['No Therapy'] + drugs
                weights = [0.8] + [0.2 / len(drugs)] * len(drugs)
            else:
                other_abs = [ab for ab in drugs if ab != current_therapy]
                choices = [current_therapy, 'No Therapy'] + other_abs
                weights = [0.6, 0.1] + [0.3 / len(other_abs)] * len(other_abs)
            new_therapy = np.random.choice(choices, p=weights)
            records.append({'SUBJECT': pid, 'DAY_NUM': day, 'drug_norm': new_therapy})
            current_therapy = new_therapy

    # Add combination therapies for testing (only for patients that exist).
    for pid in range(min(5, num_patients)):
        records.append({'SUBJECT': pid, 'DAY_NUM': 2, 'drug_norm': 'amoxicillin'})
        records.append({'SUBJECT': pid, 'DAY_NUM': 2, 'drug_norm': 'linezolid'})

    # Explicit columns keep the schema even when there are no records.
    df = pd.DataFrame(records, columns=['SUBJECT', 'DAY_NUM', 'drug_norm'])
    # Remove rows randomly to create gaps for testing.
    return remove_days_randomly(df, missing_day_fraction=0.1)
209
210
def find_missing_days(group: pd.DataFrame,
                      subject_col: str,
                      day_col: str,
                      therapy_col: str = 'drug_norm',
                      fill_value: str = 'No Therapy'):
    """Return filler rows for the gaps in one subject's day sequence.

    This is intended for ``DataFrame.groupby(subject_col).apply``. The subject
    id is read from ``group.name``, which pandas sets on each group it passes
    to ``apply``.

    Args:
        group: The rows of a single subject.
        subject_col: Name of the subject id column in the output.
        day_col: Name of the day column (integer-valued, possibly stored as
            float).
        therapy_col: Column that receives ``fill_value`` in the filler rows.
        fill_value: Therapy label assigned to each missing day.

    Returns:
        A DataFrame with one row per missing day between the subject's first
        and last recorded day, or ``None`` when there are no gaps.
    """
    # int() lets float-typed day columns (e.g. 1.0) work with range().
    first_day = int(group[day_col].min())
    last_day = int(group[day_col].max())

    # Set difference is a cheap way to find days that have no record at all.
    missing_days = sorted(set(range(first_day, last_day + 1)) - set(group[day_col]))
    if not missing_days:
        return None
    return pd.DataFrame({
        subject_col: group.name,
        day_col: missing_days,
        therapy_col: fill_value,
    })
230
231
def process_patient_data(df: pd.DataFrame,
                         level: str,
                         therapy_col: str,
                         subject_col: str,
                         day_col: str) -> pd.DataFrame:
    """
    Fill day gaps, classify therapies and collapse combinations so the data
    holds one row per subject and day ("pivot-ready").

    Missing days between a subject's first and last record are added with
    drug_norm = 'No Therapy'. The level then decides the output:

    - 'drug': drug_norm values for a subject-day are joined into one sorted
      'a + b' string;
    - 'class': drug_norm is mapped through DRUG_TO_CLASS into ``therapy_col``
      (unmapped -> 'Unknown') and joined the same way;
    - 'aware': drug_norm is mapped through DRUG_TO_AWARE_CLASS into
      ``therapy_col`` (unmapped -> 'Unknown'). For combinations the row with
      the highest category is kept (Reserve > Watch > Access > Unknown > No
      Therapy);
    - anything else: the gap-filled data is returned without collapsing.

    The input DataFrame is not modified.
    """
    group_keys = [subject_col, day_col]

    # Add 'No Therapy' rows for every gap in each subject's day sequence.
    source = df.copy()
    gap_rows = (source.groupby(subject_col)
                      .apply(find_missing_days, subject_col=subject_col, day_col=day_col)
                      .reset_index(drop=True))
    combined = pd.concat([source, gap_rows]).reset_index(drop=True)

    def join_unique(values):
        # Sorting makes 'a + b' and 'b + a' the same combination label.
        return ' + '.join(sorted(values.unique()))

    if level == 'drug':
        return (combined.groupby(group_keys)[therapy_col]
                        .apply(join_unique)
                        .reset_index())

    if level == 'class':
        lowered = combined['drug_norm'].str.lower()
        combined[therapy_col] = lowered.map(DRUG_TO_CLASS).fillna('Unknown')
        return (combined.groupby(group_keys)[therapy_col]
                        .apply(join_unique)
                        .reset_index())

    if level == 'aware':
        lowered = combined['drug_norm'].str.lower()
        combined[therapy_col] = lowered.map(DRUG_TO_AWARE_CLASS).fillna('Unknown')
        rank = {'Access': 0, 'Watch': 1, 'Reserve': 2, 'Unknown': -1, 'No Therapy': -2}
        combined['aware_ordinal'] = combined[therapy_col].map(rank)
        # idxmax keeps the first row holding the highest rank per subject-day.
        winners = combined.groupby(group_keys)['aware_ordinal'].idxmax()
        return combined.loc[winners].drop(columns='aware_ordinal')

    return combined
275
276
def limit_therapies_to_top_n(df: pd.DataFrame, therapy_col: str, n: int) -> pd.DataFrame:
    """Keep the ``n`` most frequent therapies and relabel the rest as 'Other'.

    'No Therapy' is always kept, so up to ``n + 1`` distinct labels remain.
    The caller's DataFrame is no longer modified in place; a relabelled copy
    is returned instead.

    Args:
        df: Data holding one therapy label per row in ``therapy_col``.
        therapy_col: Column to filter.
        n: Number of top therapies (by row count) to keep. ``None`` disables
            the filter and returns ``df`` unchanged.

    Returns:
        A copy of ``df`` with infrequent labels replaced by 'Other' (or ``df``
        itself when ``n`` is None).
    """
    if n is None:
        return df
    top_n = df[therapy_col].value_counts().nlargest(n).index.tolist()
    if 'No Therapy' not in top_n:
        top_n.append('No Therapy')
    limited = df.copy()
    limited[therapy_col] = limited[therapy_col].where(limited[therapy_col].isin(top_n), 'Other')
    return limited
285
286
def create_flow_from_patient_data_pivot(df: pd.DataFrame,
                                        therapy_col: str,
                                        subject_col: str,
                                        day_col: str) -> pd.DataFrame:
    """Count day-to-day therapy transitions for a Sankey diagram.

    Every pair of consecutive days present in the data is considered, whatever
    the first day number is (0, 1, ...). Subjects without a value on either
    day of a pair do not contribute to that pair.

    Args:
        df: Exactly one row per (subject, day); ``pivot`` raises otherwise.
        therapy_col: Column holding the therapy label.
        subject_col: Subject id column.
        day_col: Day column (integer-valued).

    Returns:
        DataFrame with columns item_from, item_to, count, day_from, day_to
        (one row per distinct transition). When there are no transitions, an
        empty frame with columns day_from, item_from, day_to, item_to, count
        is returned.
    """
    empty_columns = ['day_from', 'item_from', 'day_to', 'item_to', 'count']
    if df.empty:
        return pd.DataFrame(columns=empty_columns)

    # One row per subject, one column per day.
    patient_journeys = df.pivot(index=subject_col, columns=day_col, values=therapy_col)
    days = sorted(patient_journeys.columns)

    all_links = []
    for day_from, day_to in zip(days, days[1:]):
        if day_to != day_from + 1:
            continue  # only link strictly consecutive days
        transition_df = patient_journeys[[day_from, day_to]].dropna()
        if transition_df.empty:
            continue
        links = transition_df.value_counts().reset_index(name='count')
        links = links.rename(columns={day_from: 'item_from', day_to: 'item_to'})
        links['day_from'] = day_from
        links['day_to'] = day_to
        all_links.append(links)

    if not all_links:
        return pd.DataFrame(columns=empty_columns)
    return pd.concat(all_links, ignore_index=True)
305
306
def create_comprehensive_color_map(df: pd.DataFrame) -> dict:
    """Map every therapy label in a flow table to a '#rrggbb' colour.

    'No Therapy', 'Other' and 'Unknown' get fixed grey shades. The remaining
    labels are sorted and spread evenly over Matplotlib's 'tab20b' colormap,
    so the same set of labels always receives the same colours.

    Args:
        df: Flow table with ``item_from`` and ``item_to`` columns.

    Returns:
        dict of label -> hex colour (special labels first, then sorted labels).
    """
    special_colors = {'No Therapy': '#aaaaaa', 'Other': '#d3d3d3', 'Unknown': '#e5e5e5'}
    labels = set(df['item_from'].unique()) | set(df['item_to'].unique())

    color_map = {label: color for label, color in special_colors.items() if label in labels}
    therapies = sorted(labels - set(special_colors))
    if not therapies:
        return color_map

    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; use the
    # colormap registry where available and fall back for older versions.
    try:
        import matplotlib
        colormap = matplotlib.colormaps['tab20b'].resampled(len(therapies))
    except AttributeError:  # Matplotlib < 3.6 (no registry / Colormap.resampled)
        colormap = cm.get_cmap('tab20b', len(therapies))

    rgba_rows = colormap(np.linspace(0, 1, len(therapies)))
    for therapy, (r, g, b, _alpha) in zip(therapies, rgba_rows):
        color_map[therapy] = '#%02x%02x%02x' % (int(r * 255), int(g * 255), int(b * 255))
    return color_map
322
323
def plot_sankey_robust(flow_df: pd.DataFrame, therapy_colors: dict, title: str):
    """Build a Plotly Sankey figure with one column of nodes per day.

    Each (day, label) pair becomes its own node, so the same therapy on
    different days is drawn separately. Because flows only connect a day to a
    later one, no link can point backwards. Nodes get an x position
    proportional to their day, and a "Day N" label is drawn under each day.

    Args:
        flow_df: Transitions with columns day_from, item_from, day_to,
            item_to and count. It is not modified (previously the function
            added source_id/target_id columns to the caller's frame).
        therapy_colors: Label -> '#rrggbb' colour. Labels without a colour are
            drawn '#CCCCCC'; links use their source colour at 0.4 opacity.
        title: Plot title (rendered in bold).

    Returns:
        plotly.graph_objects.Figure
    """
    # Unique (day, label) nodes ordered by day then label; id = row position.
    endpoints = pd.concat([
        flow_df[['day_from', 'item_from']].rename(columns={'day_from': 'day', 'item_from': 'label'}),
        flow_df[['day_to', 'item_to']].rename(columns={'day_to': 'day', 'item_to': 'label'}),
    ])
    nodes_df = (endpoints.drop_duplicates()
                         .sort_values(['day', 'label'])
                         .reset_index(drop=True))
    # Tuple keys avoid building '<label>_day_<n>' strings (and work for
    # non-string labels).
    node_id = {key: i for i, key in enumerate(zip(nodes_df['day'], nodes_df['label']))}
    source_ids = [node_id[key] for key in zip(flow_df['day_from'], flow_df['item_from'])]
    target_ids = [node_id[key] for key in zip(flow_df['day_to'], flow_df['item_to'])]

    # Spread days evenly over [0, 1]; a single day is centred.
    unique_days = sorted(nodes_df['day'].unique())
    if len(unique_days) > 1:
        day_x_map = {day: i / (len(unique_days) - 1) for i, day in enumerate(unique_days)}
    else:
        day_x_map = {day: 0.5 for day in unique_days}
    node_x = nodes_df['day'].map(day_x_map)
    # NOTE(review): plotly.js appears to honour fixed node positions only when
    # both node.x and node.y are given (and to treat 0 as "unset"), so x alone
    # may not pin the day columns. Confirm against the Plotly Sankey docs.

    node_colors = nodes_df['label'].map(therapy_colors).fillna('#CCCCCC')

    def to_rgba(hex_color, alpha=0.4):
        """Convert '#rrggbb' to an 'rgba(r,g,b,alpha)' string for translucent links."""
        r, g, b = (int(hex_color[i:i + 2], 16) for i in (1, 3, 5))
        return f"rgba({r},{g},{b},{alpha})"

    link_colors = [to_rgba(h) for h in flow_df['item_from'].map(therapy_colors).fillna('#CCCCCC')]

    fig = go.Figure(data=[go.Sankey(
        arrangement="snap",
        node=dict(pad=20, thickness=25, line=dict(color="black", width=0.5),
                  label=nodes_df['label'], color=node_colors, x=node_x),
        link=dict(source=source_ids, target=target_ids,
                  value=flow_df['count'], color=link_colors)
    )])

    # Day labels under each column; the vertical offset depends on where the
    # script runs (see TERMINAL).
    y = -0.05 if TERMINAL else -0.17
    annotations = [
        dict(x=x, y=y, text=f"<b>Day {d}</b>", showarrow=False,
             font=dict(size=14), xref="paper", yref="paper", xanchor="center")
        for d, x in day_x_map.items()]
    fig.update_layout(
        title_text=f"<b>{title}</b>",
        title_x=0.5,
        font=dict(size=12),
        annotations=annotations,
        margin=dict(b=100)
    )

    if not TERMINAL:
        fig.update_layout(margin=dict(l=30, r=30, t=60, b=100))

    return fig
367
368
369
370
371
372 # =============================================================================
373 # MAIN
374 # =============================================================================
if __name__ == '__main__':

    # .. note:: To check that the logic is correct, add a test case: create a
    #           synthetic patient with a known therapy progression over
    #           several days, run the script, and confirm that the final plot
    #           shows exactly that pathway (see create_validation_data()).

    # Configuration
    N_PATIENTS = 200            # Number of synthetic patients
    MAX_LOS = 6                 # Maximum length of stay (days)
    TOP_N_THERAPIES = 10        # Keep only the most common therapies
    ANALYSIS_LEVEL = 'aware'    # One of 'drug', 'class' or 'aware'

    therapy_col = COLUMN_FOR_ANALYSIS[ANALYSIS_LEVEL]
    subject_col, day_col = 'SUBJECT', 'DAY_NUM'

    # Input data: synthetic cohort. Swap in create_validation_data() to
    # eyeball the transitions on a tiny known example.
    patient_df = create_synthetic_data(num_patients=N_PATIENTS, max_los=MAX_LOS)

    # Pipeline: fill gaps / classify -> keep top N -> count transitions.
    processed_df = process_patient_data(
        df=patient_df,
        level=ANALYSIS_LEVEL,
        therapy_col=therapy_col,
        subject_col=subject_col,
        day_col=day_col,
    )
    filtered_df = limit_therapies_to_top_n(
        df=processed_df,
        therapy_col=therapy_col,
        n=TOP_N_THERAPIES,
    )
    flow_data = create_flow_from_patient_data_pivot(
        df=filtered_df,
        therapy_col=therapy_col,
        subject_col=subject_col,
        day_col=day_col,
    )

    if flow_data.empty:
        print("No flow data to plot. The dataset might be too small or filtered.")
    else:
        # .. note:: A custom colour scheme can make the plot easier to read,
        #           e.g. green / yellow / red for Access / Watch / Reserve, or
        #           one hue per drug class with shades for its member drugs.
        colors = create_comprehensive_color_map(flow_data)

        plot_title = f"Antimicrobial Therapy Transitions ({ANALYSIS_LEVEL.title()} View)"
        if TOP_N_THERAPIES:
            plot_title += f" - Top {TOP_N_THERAPIES}"
        fig = plot_sankey_robust(flow_data, colors, title=plot_title)

        # Show
        from plotly.io import show
        show(fig)
Total running time of the script: ( 0 minutes 1.277 seconds)