08. Patient therapy flow (sankey)

This Sankey diagram visualizes treatment pathways by showing the flow of therapies between consecutive days. The width of each path is proportional to the volume of these transitions, which should be interpreted based on the input data’s structure. If the data contains a single entry for a given therapy per subject per day, the flow represents the number of patients. However, if the data includes multiple entries for the same therapy on a given day, such as for each dose, then the flow represents the total number of applications.

The level of detail of the visualization is controlled by the ANALYSIS_LEVEL constant in the main section of the script. Setting it to 'drug' displays the trajectories by specific drug, 'class' by broader drug category, and 'aware' by the WHO AWaRe classification (Access, Watch, Reserve).

 19 # Libraries
 20 import pandas as pd
 21 import numpy as np
 22 import plotly.graph_objects as go
 23 import matplotlib.cm as cm
 24
 25 try:
 26     __file__
 27     TERMINAL = True
 28 except:
 29     TERMINAL = False
 30
 31 # =============================================================================
 32 # LOOKUP TABLES & CONSTANTS
 33 # =============================================================================
 34
 35 def drug_to_aware_default():
 36     """Create default drug to aware classes."""
 37     return {
 38         "amikacin": "Access", "amoxicillin": "Access", "co-amoxiclav": "Access",
 39         "benzylpenicillin sodium": "Access", "cefalexin": "Access", "cefazolin": "Access",
 40         "chloramphenicol": "Access", "clindamycin": "Access", "doxycycline": "Access",
 41         "flucloxacillin": "Access", "gentamicin": "Access", "mecillinam": "Access",
 42         "metronidazole": "Access", "nitrofurantoin": "Access", "pivmecillinam": "Access",
 43         "co-trimoxazole": "Access", "trimethoprim": "Access", "azithromycin": "Watch",
 44         "cefepime": "Watch", "ceftazidime": "Watch", "ceftriaxone": "Watch",
 45         "cefuroxime": "Watch", "ciprofloxacin": "Watch", "clarithromycin": "Watch",
 46         "ertapenem": "Watch", "erythromycin": "Watch", "fosfomycin": "Watch",
 47         "levofloxacin": "Watch", "meropenem": "Watch", "moxifloxacin": "Watch",
 48         "piperacillin-tazobactam": "Watch", "teicoplanin": "Watch", "temocillin": "Watch",
 49         "vancomycin": "Watch", "aztreonam": "Reserve", "cefiderocol": "Reserve",
 50         "ceftazidime/avibactam": "Reserve", "ceftolozane/tazobactam": "Reserve",
 51         "colistin": "Reserve", "daptomycin": "Reserve", "linezolid": "Reserve",
 52         "tazocin": "Watch",
 53         'no therapy': 'No Therapy'
 54     }
 55
 56 def drug_to_aware_class_winnie():
 57     """Create drug to aware class as defined by winnie"""
 58     ACCESS_LIST = [
 59         "Amikacin", "Amoxicillin", "Amoxicillin (contains penicillin)",
 60         "Amoxicillin/clavulanic-acid", "Amoxiclavulanic-acid",
 61         "Benzylpenicillin sodium", "Co-amoxiclav", "Co-amoxiclav (contains penicillin)",
 62         "Co-amoxiclav (contains penicillin) (anes)",
 63         "Co-amoxiclav 400mg/57mg in 5ml oral suspension (contains penicillin)",
 64         "co-amoxiclav 250mg/62mg in 5ml oral suspension (contains penicillin)",
 65         "Cefalexin", "Cefazolin", "Chloramphenicol", "Clindamycin", "Doxycycline",
 66         "Flucloxacillin (contains penicillin)", "Gentamicin", "Gentamicin (anes)",
 67         "Mecillinam", "Metronidazole", "Metronidazole (anes)", "Metronidazole_IV",
 68         "Metronidazole_oral", "Nitrofurantoin", "Co-amoxiclav in suspension",
 69         "Pivmecillinam", "Pivmecillinam (contains penicillin)",
 70         "Sulfamethoxazole/trimethoprim", "Trimethoprim", "co-trimoxazole"
 71     ]
 72     WATCH_LIST = [
 73         "Azithromycin", "Cefepime", "Ceftazidime", "Ceftriaxone", "Cefuroxime",
 74         "Cefuroxime (anes)", "Ciprofloxacin", "Clarithromycin", "Ertapenem",
 75         "Erythromycin", "Fosfomycin", "Fosfomycin_oral", "Levofloxacin", "Meropenem",
 76         "Moxifloxacin", "Piperacillin/tazobactam", "piperacillin-tazobactam (contains penicillin)",
 77         "piperacillin + tazobactam (contains penicillin)", "Teicoplanin", "Teicoplanin (anes)",
 78         "Temocillin", "Temocillin (contains penicillin)", "Vancomycin",
 79         "Vancomycin (anes)", "Vancomycin (anes) 1g", "Vancomycin_IV", "Vancomycin_oral"
 80     ]
 81     RESERVE_LIST = [
 82         "Aztreonam", "Cefiderocol", "Ceftazidime/avibactam", "Ceftolozane/tazobactam",
 83         "Colistin", "Colistin_IV", "Dalbavancin", "Daptomycin", "Fosfomycin_IV",
 84         "Iclaprim", "Linezolid", "Tigecycline"
 85     ]
 86
 87     # Define lookup table as this is more efficient for lookup than searching
 88     # through lists every time. We convert all drug names to lowercase to
 89     # ensure case-insensitive matching.
 90     d = {drug.lower(): "Access" for drug in ACCESS_LIST}
 91     d.update({drug.lower(): "Watch" for drug in WATCH_LIST})
 92     d.update({drug.lower(): "Reserve" for drug in RESERVE_LIST})
 93     d.update({'No therapy'.lower(): 'No Therapy'})
 94     return d
 95
# Active drug -> AWaRe class lookup (lowercase keys) used by the synthetic
# data generator and by the 'aware' analysis level.
DRUG_TO_AWARE_CLASS = drug_to_aware_default()
 97
 98 # .. note:: Please be aware that this list is not exhaustive. Any antibiotics not assigned
 99 #           to a specific category are currently labeled as 'Unknown'. To ensure complete
100 #           coverage, please expand the list to encompass all items.
101 # .. note:: Ensure keys are lowercase.
102
# Lowercase drug name -> drug class, used by the 'class' analysis level.
DRUG_TO_CLASS = {
    "amoxicillin": "Penicillin",
    "flucloxacillin": "Penicillin",
    "pivmecillinam": "Penicillin",
    "co-amoxiclav": "Penicillin and beta-lactamase inhibitor",
    "tazocin": "Penicillin and beta-lactamase inhibitor",
    "piperacillin-tazobactam": "Penicillin and beta-lactamase inhibitor",
    "cefuroxime": "2nd Gen Cephalosporin",
    "ceftriaxone": "3rd Gen Cephalosporin",
    "ceftazidime": "3rd Gen Cephalosporin",
    "cefepime": "4th Gen Cephalosporin",
    "meropenem": "Carbapenem",
    "aztreonam": "Monobactam",
    "gentamicin": "Aminoglycoside",
    "amikacin": "Aminoglycoside",
    "ciprofloxacin": "Fluoroquinolone",
    "clarithromycin": "Macrolide",
    "erythromycin": "Macrolide",
    "vancomycin": "Glycopeptide",
    "teicoplanin": "Glycopeptide",
    "linezolid": "Oxazolidinone",
    "metronidazole": "Nitroimidazole",
    "doxycycline": "Tetracycline",
    "fosfomycin": "Phosphonic acid",
    "co-trimoxazole": "Folate pathway inhibitor",
    "chloramphenicol": "Amphenicol",
    "temocillin": "Penicillin",
    # Keeps 'No Therapy' as its own category (otherwise it would be set to
    # 'Unknown').
    "no therapy": "No Therapy",
}

# Analysis level -> name of the column that holds the therapy label.
COLUMN_FOR_ANALYSIS = {
    "drug": "drug_norm",
    "class": "drug_class",
    "aware": "aware_class",
}
122
123
124 # =============================================================================
125 # HELPER FUNCTIONS
126 # =============================================================================
127
def create_validation_data():
    """
    Build a tiny two-patient dataset for checking the Sankey logic by eye.

    The journeys are chosen to exercise:
    - a single therapy -> combination therapy transition,
    - flows from different sources merging into one combination node,
    - a combination therapy -> single therapy transition,
    - transitions to 'No Therapy' and journeys that simply end.

    Returns:
        pd.DataFrame: One row per (SUBJECT, DAY_NUM, drug_norm) record.
    """
    journeys = {
        # Patient 101: longer journey. Day 2 is a combination, day 3 leaves
        # the combination, and there is no record for day 6.
        101: [
            (1, 'amoxicillin'),
            (2, 'amoxicillin'),
            (2, 'linezolid'),
            (3, 'vancomycin'),
            (4, 'No Therapy'),
            (5, 'No Therapy'),
            (7, 'vancomycin'),
            (8, 'vancomycin'),
        ],
        # Patient 102: shorter journey that merges into the same day-2
        # combination and ends on day 3.
        102: [
            (1, 'ciprofloxacin'),
            (2, 'amoxicillin'),
            (2, 'linezolid'),
            (3, 'vancomycin'),
        ],
    }

    rows = []
    for subject, visits in journeys.items():
        rows.extend((subject, day, drug) for day, drug in visits)
    return pd.DataFrame(rows, columns=['SUBJECT', 'DAY_NUM', 'drug_norm'])
158
159
def remove_days_randomly(df, missing_day_fraction=0.01):
    """Drop a random fraction of the rows whose DAY_NUM is not 1.

    Args:
        df: DataFrame with a 'DAY_NUM' column.
        missing_day_fraction: Share of the non-first-day rows to remove.
            When it is not greater than zero, ``df`` is returned untouched.

    Returns:
        The input frame itself, or a new frame without the sampled rows.
    """
    if not missing_day_fraction > 0:
        return df

    # Rows with DAY_NUM == 1 are never candidates for removal.
    candidates = df.index[df['DAY_NUM'] != 1]
    n_remove = int(len(candidates) * missing_day_fraction)
    removed = np.random.choice(candidates, size=n_remove, replace=False)
    return df.drop(removed)
172
173
def create_synthetic_data(num_patients=50, max_los=7):
    """
    Creates synthetic patient antibiotic data with variable lengths of stay.

    Each patient starts on a random antibiotic. On every following day the
    patient either keeps the current therapy, switches to 'No Therapy', or
    switches to another antibiotic. A patient on 'No Therapy' stays there
    with probability 0.8 and otherwise restarts on a random antibiotic.

    .. note:: It adds combined therapies (amoxicillin + linezolid on day 2)
              to the first 5 patients, or to all of them if fewer exist.
    .. note:: It removes 10% of the non-first-day rows at random, so some
              days carry no therapy information.

    Args:
        num_patients: The number of patients to simulate.
        max_los: The maximum possible length of stay for any patient.

    Returns:
        pd.DataFrame: Columns 'SUBJECT', 'DAY_NUM' and 'drug_norm'.
    """
    no_therapy = 'No Therapy'

    # The lookup table contains a lowercase 'no therapy' placeholder. It is
    # excluded here so the absence of therapy has exactly one spelling and
    # is never drawn as if it were an antibiotic.
    antibiotics = [drug for drug in DRUG_TO_AWARE_CLASS
                   if drug.lower() != no_therapy.lower()]

    records = []
    for pid in range(num_patients):
        los = np.random.randint(2, max_los + 1)
        current_therapy = np.random.choice(antibiotics)
        records.append({'SUBJECT': pid, 'DAY_NUM': 1, 'drug_norm': current_therapy})
        for day in range(2, los + 1):
            if current_therapy == no_therapy:
                # The first probability must line up with 'No Therapy'
                # itself, so it is listed explicitly as the first option.
                options = [no_therapy] + antibiotics
                probs = [0.8] + [0.2 / len(antibiotics)] * len(antibiotics)
            else:
                other_abs = [ab for ab in antibiotics if ab != current_therapy]
                options = [current_therapy, no_therapy] + other_abs
                probs = [0.6, 0.1] + [0.3 / len(other_abs)] * len(other_abs)
            new_therapy = np.random.choice(options, p=probs)
            records.append({'SUBJECT': pid, 'DAY_NUM': day, 'drug_norm': new_therapy})
            current_therapy = new_therapy

    # Add combination therapies for testing. The range is capped so that no
    # records are created for patients that were never simulated.
    for pid in range(min(5, num_patients)):
        records.append({'SUBJECT': pid, 'DAY_NUM': 2, 'drug_norm': 'amoxicillin'})
        records.append({'SUBJECT': pid, 'DAY_NUM': 2, 'drug_norm': 'linezolid'})

    # Remove days randomly for testing
    return remove_days_randomly(pd.DataFrame(records), missing_day_fraction=0.1)
209
210
def find_missing_days(group: pd.DataFrame,
                      subject_col: str,
                      day_col: str):
    """List the days without a record between a subject's first and last day.

    Meant to be applied to one group of a ``groupby``: the subject
    identifier is read from ``group.name``.

    Args:
        group: Rows belonging to a single subject.
        subject_col: Name of the subject column in the returned frame.
        day_col: Name of the (integer) day column.

    Returns:
        A DataFrame with one 'No Therapy' row per missing day, or None when
        every day in the range is present.
    """
    first_day = group[day_col].min()
    last_day = group[day_col].max()

    # Set membership keeps the gap scan cheap.
    recorded = set(group[day_col])
    gaps = [day for day in range(first_day, last_day + 1) if day not in recorded]

    if not gaps:
        return None  # Nothing to fill in for this subject

    return pd.DataFrame({
        subject_col: group.name,
        day_col: gaps,
        'drug_norm': 'No Therapy'
    })
230
231
def process_patient_data(df: pd.DataFrame,
                         level: str,
                         therapy_col: str,
                         subject_col: str,
                         day_col: str) -> pd.DataFrame:
    """
    Fill day gaps with 'No Therapy' and reduce the data to a single therapy
    label per subject and day, so that it is "pivot-ready".

    Args:
        df: Raw records with a 'drug_norm' column.
        level: 'drug', 'class' or 'aware'. Any other value returns the
            gap-filled data without aggregation.
        therapy_col: Column that receives (or already holds) the label.
        subject_col: Column identifying the subject.
        day_col: Column holding the day number.

    Returns:
        pd.DataFrame: For 'drug' and 'class', one row per subject/day with
        combinations joined as 'A + B'. For 'aware', the row with the
        highest-ranked AWaRe class per subject/day.
    """
    def _join_unique(values):
        # Sorted so the same combination always yields the same label.
        return ' + '.join(sorted(values.unique()))

    data = df.copy()

    # Rows for days that have no record at all, labelled 'No Therapy'.
    gap_rows = data.groupby(subject_col) \
        .apply(find_missing_days, subject_col=subject_col, day_col=day_col) \
        .reset_index(drop=True)
    data = pd.concat([data, gap_rows]).reset_index(drop=True)

    if level == 'aware':
        data[therapy_col] = data['drug_norm'].str.lower() \
            .map(DRUG_TO_AWARE_CLASS).fillna('Unknown')
        # A combination is represented by its highest-order class; 'Unknown'
        # and 'No Therapy' rank below every real class.
        aware_order = {'Access': 0, 'Watch': 1, 'Reserve': 2, 'Unknown': -1, 'No Therapy': -2}
        data['aware_ordinal'] = data[therapy_col].map(aware_order)
        top_rows = data.groupby([subject_col, day_col])['aware_ordinal'].idxmax()
        return data.loc[top_rows].drop(columns='aware_ordinal')

    if level == 'class':
        data[therapy_col] = data['drug_norm'].str.lower() \
            .map(DRUG_TO_CLASS).fillna('Unknown')

    if level in ('class', 'drug'):
        # Collapse combination therapies into one sorted label per day.
        return data.groupby([subject_col, day_col])[therapy_col] \
            .apply(_join_unique) \
            .reset_index()

    return data
275
276
def limit_therapies_to_top_n(df: pd.DataFrame, therapy_col: str, n: int) -> pd.DataFrame:
    """Keeps the top N most frequent therapies and groups the rest into 'Other'.

    'No Therapy' is always kept as its own category, even when it is not
    among the N most frequent labels.

    Args:
        df: Data with one therapy label per row in ``therapy_col``.
        therapy_col: Name of the column holding the therapy label.
        n: Number of most frequent therapies to keep. ``None`` disables the
            filter and returns ``df`` as is.

    Returns:
        pd.DataFrame: ``df`` itself when ``n`` is None, otherwise a new
        frame with the relabelled column. The input is never modified.
    """
    if n is None:
        return df

    top_n = df[therapy_col].value_counts().nlargest(n).index.tolist()
    if 'No Therapy' not in top_n:
        top_n.append('No Therapy')

    # Relabel on a copy so the caller's DataFrame is left untouched.
    result = df.copy()
    result[therapy_col] = result[therapy_col].where(
        result[therapy_col].isin(top_n), 'Other')
    return result
285
286
def create_flow_from_patient_data_pivot(df: pd.DataFrame,
                                        therapy_col: str,
                                        subject_col: str,
                                        day_col: str) -> pd.DataFrame:
    """Transforms processed patient data into a Sankey-ready flow DataFrame.

    The data is pivoted to one row per subject and one column per day. For
    every pair of consecutive days (d, d + 1) that both exist, the
    (therapy on d, therapy on d + 1) pairs are counted over the subjects
    that have a value on both days.

    Args:
        df: Data with exactly one row per subject and day (``pivot`` raises
            on duplicates).
        therapy_col: Column holding the therapy label.
        subject_col: Column identifying the subject.
        day_col: Column holding the day number.

    Returns:
        pd.DataFrame: Columns 'item_from', 'item_to', 'count', 'day_from'
        and 'day_to'. An empty frame with those column names is returned
        when there is no transition at all.
    """
    patient_journeys = df.pivot(index=subject_col, columns=day_col, values=therapy_col)

    # Walk the days that are actually present instead of assuming the first
    # day is 1, so data starting at day 0 (or earlier) keeps its transitions.
    days = sorted(patient_journeys.columns)
    available_days = set(days)

    all_links = []
    for day_from in days:
        day_to = day_from + 1
        if day_to not in available_days:
            continue
        # Only subjects observed on both days contribute to this transition.
        transition_df = patient_journeys[[day_from, day_to]].dropna()
        links = transition_df.value_counts().reset_index(name='count')
        links = links.rename(columns={day_from: 'item_from', day_to: 'item_to'})
        links['day_from'], links['day_to'] = day_from, day_to
        all_links.append(links)

    if not all_links:
        return pd.DataFrame(columns=['day_from', 'item_from', 'day_to', 'item_to', 'count'])
    return pd.concat(all_links, ignore_index=True)
305
306
def create_comprehensive_color_map(df: pd.DataFrame) -> dict:
    """Generates a consistent color map, including special colors for specified categories.

    Args:
        df: Flow data with 'item_from' and 'item_to' columns.

    Returns:
        dict: Therapy label -> '#rrggbb' color. 'No Therapy', 'Other' and
        'Unknown' get fixed greys; every other label gets a color sampled
        from the 'tab20b' colormap, assigned in sorted label order.
    """
    therapies = sorted(set(df['item_from'].unique()) | set(df['item_to'].unique()))

    color_map = {}
    special_colors = {'No Therapy': '#aaaaaa', 'Other': '#d3d3d3', 'Unknown': '#e5e5e5'}
    for category, color in special_colors.items():
        if category in therapies:
            color_map[category] = color
            therapies.remove(category)
    if not therapies:
        return color_map

    # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9. Prefer the
    # colormap registry and fall back to the old call on releases that do
    # not provide ``colormaps`` / ``resampled`` yet.
    import matplotlib
    n_colors = len(therapies)
    try:
        colormap = matplotlib.colormaps['tab20b'].resampled(n_colors)
    except AttributeError:
        colormap = cm.get_cmap('tab20b', n_colors)

    hex_colors = ['#%02x%02x%02x' % (int(r * 255), int(g * 255), int(b * 255))
                  for r, g, b, _ in colormap(np.linspace(0, 1, n_colors))]
    color_map.update(zip(therapies, hex_colors))
    return color_map
322
323
def plot_sankey_robust(flow_df: pd.DataFrame, therapy_colors: dict, title: str):
    """Generates a Sankey diagram using a structural method that prevents backward links.

    Every (day, therapy) pair becomes its own node and each node is pinned
    to an x position derived from its day, so links always run from one day
    column to the next.

    Args:
        flow_df: Flow data with 'day_from', 'item_from', 'day_to',
            'item_to' and 'count' columns. It is not modified.
        therapy_colors: Therapy label -> '#rrggbb' color. Labels without an
            entry are drawn in '#CCCCCC'.
        title: Figure title (rendered in bold).

    Returns:
        plotly.graph_objects.Figure: The Sankey figure.
    """
    # Work on a copy so the helper id columns added below do not leak into
    # the caller's DataFrame.
    flow_df = flow_df.copy()

    # One node per distinct (day, label) pair, ordered by day then label.
    source_nodes = flow_df[['day_from', 'item_from']] \
        .rename(columns={'day_from': 'day', 'item_from': 'label'})
    target_nodes = flow_df[['day_to', 'item_to']] \
        .rename(columns={'day_to': 'day', 'item_to': 'label'})
    nodes_df = pd.concat([source_nodes, target_nodes]) \
        .drop_duplicates().sort_values(['day', 'label']) \
        .reset_index(drop=True)
    nodes_df['id'] = nodes_df.index

    # Resolve link endpoints to node ids through a '<label>_day_<day>' key.
    node_map = pd.Series(nodes_df.id.values,
                         index=nodes_df['label'] + '_day_' + nodes_df['day'].astype(str))
    flow_df['source_id'] = (flow_df['item_from'] + '_day_'
                            + flow_df['day_from'].astype(str)).map(node_map)
    flow_df['target_id'] = (flow_df['item_to'] + '_day_'
                            + flow_df['day_to'].astype(str)).map(node_map)

    # Spread the days evenly over [0, 1]; a single day sits in the middle.
    unique_days = sorted(nodes_df['day'].unique())
    n_days = len(unique_days)
    day_x_map = {day: i / (n_days - 1) if n_days > 1 else 0.5
                 for i, day in enumerate(unique_days)}
    node_x = nodes_df['day'].map(day_x_map)

    node_colors = nodes_df['label'].map(therapy_colors).fillna('#CCCCCC')
    # Links take the color of their source therapy, made semi-transparent.
    link_colors = flow_df['item_from'].map(therapy_colors).fillna('#CCCCCC') \
        .apply(lambda h: f"rgba({int(h[1:3],16)},{int(h[3:5],16)},{int(h[5:7],16)},0.4)")

    fig = go.Figure(data=[go.Sankey(
        arrangement="snap",
        node=dict(pad=20, thickness=25, line=dict(color="black", width=0.5),
                  label=nodes_df['label'], color=node_colors, x=node_x),
        link=dict(source=flow_df['source_id'], target=flow_df['target_id'],
                  value=flow_df['count'], color=link_colors)
    )])

    # "Day N" captions below each node column; the vertical offset depends
    # on the TERMINAL flag.
    y = -0.05 if TERMINAL else -0.17
    annotations = [
        dict(x=x, y=y, text=f"<b>Day {d}</b>", showarrow=False,
             font=dict(size=14), xref="paper", yref="paper", xanchor="center")
        for d, x in day_x_map.items()]
    fig.update_layout(
        title_text=f"<b>{title}</b>",
        title_x=0.5,
        font=dict(size=12),
        annotations=annotations,
        margin=dict(b=100)
    )

    if not TERMINAL:
        fig.update_layout(margin=dict(l=30, r=30, t=60, b=100))

    return fig
367
368
369
370
371
372 # =============================================================================
373 #                                   MAIN
374 # =============================================================================
if __name__ == '__main__':

    # .. note:: To check that the logic is correct, add a test case: build a
    #           synthetic patient with a known therapy progression over
    #           several days, run the script, and confirm that the final
    #           plot reflects that pathway.

    # Constants
    N_PATIENTS = 200           # Number of patients
    MAX_LOS = 6                # Maximum length of stay.
    TOP_N_THERAPIES = 10       # Filter by top most common therapies
    ANALYSIS_LEVEL = 'aware'   # Options ('drug', 'class' and 'aware')

    # Column names shared by the processing steps.
    therapy_col = COLUMN_FOR_ANALYSIS[ANALYSIS_LEVEL]
    subject_col = 'SUBJECT'
    day_col = 'DAY_NUM'
    column_kwargs = dict(therapy_col=therapy_col,
                         subject_col=subject_col,
                         day_col=day_col)

    # Generate synthetic data. Swap in create_validation_data() to inspect
    # the small hand-written dataset instead.
    patient_df = create_synthetic_data(num_patients=N_PATIENTS, max_los=MAX_LOS)

    # Pre-process patient data
    processed_df = process_patient_data(df=patient_df,
                                        level=ANALYSIS_LEVEL,
                                        **column_kwargs)

    # Filter top n therapies
    filtered_df = limit_therapies_to_top_n(df=processed_df,
                                           therapy_col=therapy_col,
                                           n=TOP_N_THERAPIES)

    # Generate flow data
    flow_data = create_flow_from_patient_data_pivot(df=filtered_df,
                                                    **column_kwargs)

    # Display
    if flow_data.empty:
        print("No flow data to plot. The dataset might be too small or filtered.")
    else:

        # .. note:: A custom color scheme can make the plot easier to read.
        #           For example, use green, yellow, and red for 'access',
        #           'watch', and 'reserve', or give each drug class a base
        #           color and each drug within the class a shade of it.

        colors = create_comprehensive_color_map(flow_data)
        plot_title = f"Antimicrobial Therapy Transitions ({ANALYSIS_LEVEL.title()} View)"
        if TOP_N_THERAPIES:
            plot_title += f" - Top {TOP_N_THERAPIES}"
        fig = plot_sankey_robust(flow_data, colors, title=plot_title)

        # Show
        from plotly.io import show
        show(fig)

Total running time of the script: ( 0 minutes 1.277 seconds)

Gallery generated by Sphinx-Gallery