Note
Click here to download the full example code
06. Plot Treemap with MIMIC
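This example displays a Treemap. Note that the code below does not yet load the MIMIC dataset: it uses the plotly coffee-flavors sample data and a small hand-built hierarchy.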
Warning
This example is not yet complete.
Out:
Data:
ids labels parents
0 Coffee Coffee Flavors NaN
1 Aromas Aromas Coffee
2 Tastes Tastes Coffee
3 Aromas-Enzymatic Enzymatic Aromas
4 Aromas-Sugar Browning Sugar Browning Aromas
.. ... ... ...
92 Pungent-Thyme Thyme Spicy-Pungent
93 Smokey-Tarry Tarry Carbony-Smokey
94 Smokey-Pipe Tobacco Pipe Tobacco Carbony-Smokey
95 Ashy-Burnt Burnt Carbony-Ashy
96 Ashy-Charred Charred Carbony-Ashy
[97 rows x 3 columns]
DF:
ids labels parents
0 Coffee Coffee Flavors NaN
1 Aromas Aromas Coffee
2 Tastes Tastes Coffee
3 Aromas-Enzymatic Enzymatic Aromas
4 Aromas-Sugar Browning Sugar Browning Aromas
.. ... ... ...
92 Pungent-Thyme Thyme Spicy-Pungent
93 Smokey-Tarry Tarry Carbony-Smokey
94 Smokey-Pipe Tobacco Pipe Tobacco Carbony-Smokey
95 Ashy-Burnt Burnt Carbony-Ashy
96 Ashy-Charred Charred Carbony-Ashy
[97 rows x 3 columns]
10 # Libraries
11 import pandas as pd
12
13
14 from plotly.io import show
15
# ``__file__`` is only defined when the code runs from a file; referencing
# it elsewhere raises NameError. Catch only that error so unrelated
# failures (e.g. KeyboardInterrupt) are not silently swallowed.
try:
    __file__
except NameError:
    TERMINAL = False
else:
    TERMINAL = True
21
22
23 # ---------------------
24 # Helper method
25 # ---------------------
26
27 # Methods
def build_hierarchical_dataframe(df, levels, value_column, color_columns=None):
    """
    Build a hierarchy of levels for Sunburst or Treemap charts.

    Levels are given starting from the bottom to the top of the hierarchy,
    ie the last level corresponds to the root.

    Parameters
    ----------
    df: pd.DataFrame
        Input data with one column per level plus the value column.
    levels: list
        Column names, ordered from the bottom (leaves) to the top (root).
    value_column: str
        Name of the column that is summed for each node. A one-element
        list/tuple is also accepted and unwrapped.
    color_columns: list, optional
        Currently unused; the ``color`` column of the result is left empty.

    Returns
    -------
    pd.DataFrame
        Columns ``id``, ``parent``, ``value`` and ``color``. There is one
        row per group at each level, followed by a final row with id
        ``'total'`` and parent ``''`` holding the overall sum.
    """
    columns = ['id', 'parent', 'value', 'color']
    levels = list(levels)

    # Tolerate callers that wrap the column name in a one-element list.
    if isinstance(value_column, (list, tuple)) and len(value_column) == 1:
        value_column = value_column[0]

    # Collect one frame per level and concatenate once at the end;
    # DataFrame.append was removed in pandas 2.0.
    frames = []
    for i, level in enumerate(levels):
        # Only the value column is summed; the other columns are never
        # read, and summing them may fail for non-numeric data.
        dfg = df.groupby(levels[i:])[value_column].sum().reset_index()
        if i < len(levels) - 1:
            parent = dfg[levels[i + 1]]
        else:
            # The top level hangs from the synthetic 'total' node.
            parent = 'total'
        frames.append(pd.DataFrame({
            'id': dfg[level],
            'parent': parent,
            'value': dfg[value_column],
        }))

    # Synthetic root holding the grand total.
    frames.append(pd.DataFrame([{
        'id': 'total',
        'parent': '',
        'value': df[value_column].sum(),
    }]))

    # reindex adds the (empty) 'color' column and fixes the column order.
    return pd.concat(frames, ignore_index=True).reindex(columns=columns)
54
def load_sunburst():
    """Download the sunburst sample CSV and return it as a DataFrame.

    The file is read over HTTP with ``pd.read_csv``, so a network
    connection is required.
    """
    # The address is assembled from two pieces to keep lines short.
    location = ''.join([
        'https://raw.githubusercontent.com/plotly/',
        'datasets/96c0bd/sunburst-coffee-flavors-complete.csv',
    ])
    return pd.read_csv(location)
64
def load_microbiology_nhs(n=10000):
    """Loads and formats microbiology data.

    Parameters
    ----------
    n: int
        Forwarded to ``load_data_nhs`` as ``nrows``.

    Returns
    -------
    pd.DataFrame
        The frame produced by ``build_hierarchical_dataframe``, with the
        hierarchy antimicrobial -> microorganism -> specimen -> total.
    """
    # Libraries
    from pyamr.core.sari import sari
    from pyamr.datasets.load import load_data_nhs

    # Load data
    data, antimicrobials, microorganisms = \
        load_data_nhs(nrows=n)

    # Keep only blood culture specimens
    data = data[data.specimen_code.isin(['BLOOD CULTURE'])]

    # Create DataFrame: one row per (specimen, microorganism,
    # antimicrobial) and one column per sensitivity outcome, holding counts.
    dataframe = data.groupby(['specimen_code',
                              'microorganism_code',
                              'antimicrobial_code',
                              'sensitivity']) \
                    .size().unstack().fillna(0)

    # Compute frequency
    dataframe['freq'] = dataframe.sum(axis=1)
    # NOTE(review): each sari call receives the frame including the
    # columns added just above it - confirm that sari ignores them.
    dataframe['sari'] = sari(dataframe, strategy='hard')
    dataframe['sari_medium'] = sari(dataframe, strategy='medium')
    dataframe['sari_soft'] = sari(dataframe, strategy='soft')
    dataframe = dataframe.reset_index()

    # Add info for popup (micro and abxs)
    dataframe = dataframe.merge(antimicrobials,
        how='left', left_on='antimicrobial_code',
        right_on='antimicrobial_code')
    dataframe = dataframe.merge(microorganisms,
        how='left', left_on='microorganism_code',
        right_on='microorganism_code')

    # Format dataframe
    dataframe = dataframe.round(decimals=3)

    # Configuration (levels listed from the root down to the leaves)
    LEVELS = ['specimen_code', 'microorganism_code', 'antimicrobial_code']
    COLORS = ['sari']
    VALUE = 'freq'

    dataframe = dataframe[LEVELS + COLORS + [VALUE]]

    aux2 = dataframe.groupby(LEVELS).agg('sum').reset_index()

    # Return
    # The helper expects levels from the bottom to the top of the
    # hierarchy, hence the reversal. Keyword arguments are used so the
    # value and colour columns cannot be swapped (the previous positional
    # call passed COLORS as the value column).
    aux = build_hierarchical_dataframe(
        aux2, LEVELS[::-1], value_column=VALUE, color_columns=COLORS)

    return aux
117
118
119 # -----------------------------------
120 # Display basic
121 # -----------------------------------
122 # Libraries
123 import plotly.graph_objects as go
124
# Fetch the sample hierarchy
df = load_sunburst()

# Echo the loaded table
print("\nData:", df, sep="\n")

# Hover template (defined here but not passed to the trace below)
htmp = ''.join([
    '<b>%{label}</b><br>',
    'Sales:%{value}<br>',
    'Success rate: %{color:.2f}',
])

# Start from an empty figure and attach the treemap trace. No marker
# colours, hovertemplate or maxdepth are configured.
fig = go.Figure()
fig.add_trace(go.Treemap(
    ids=df.ids,
    labels=df.labels,
    parents=df.parents,
    pathbar_textfont_size=15,
    root_color="lightgrey",
    branchvalues='total',
))

# Layout: hide text below the minimum size and set the figure margins
fig.update_layout(
    uniformtext_minsize=10,
    uniformtext_mode='hide',
    margin=dict(t=50, l=25, r=25, b=25),
)
159
160
161 # -----------------------------------
162 # Display NHS
163 # -----------------------------------
164 # Load data
165 #df = load_microbiology_nhs()
166 #df = df.drop_duplicates(subset=['id', 'parent'])
167
168 # Show
# NOTE: the NHS loader above is commented out, so this prints whatever
# ``df`` currently holds.
print("DF:")
print(df)

# Create data: a small hand-built hierarchy (BLD -> SAUR -> PENI, CIPRO)
df = pd.DataFrame()
df['id'] = ['BLD', 'SAUR', 'PENI', 'CIPRO']
df['parent'] = [None, 'BLD', 'SAUR', 'SAUR']
df['value'] = [0, 1, 2, 3]
df['info'] = ['info1', 'info2', 'info3', 'info4']

# Define template (currently unused: the hovertemplate/texttemplate
# options are not passed to the trace below)
htmp = '<b> %{id} </b><br>'
htmp+= 'Info: %{info} <br>'
htmp+= 'Value %{value}'

# Create figure
fig = go.Figure(go.Treemap(
    #ids=df.id,
    labels=df.id,
    values=df.value,
    parents=df.parent,
    pathbar_textfont_size=15,
    root_color="lightgrey",
    #maxdepth=3,
    # The values above are per-node amounts, not subtree totals (SAUR=1
    # is smaller than PENI + CIPRO = 5). With branchvalues='total' plotly
    # requires every parent to be at least the sum of its children and
    # otherwise draws nothing, so 'remainder' is used: each branch value
    # is the extra amount on top of its descendants.
    branchvalues='remainder',
    #marker=dict(
    #    colors=df_all_trees['color'],
    #    colorscale='RdBu',
    #    cmid=average_score),
    #hovertemplate=htmp,  # overrides hoverinfo
    #texttemplate=htmp,   # overrides textinfo
    #marker_colorscale='Blues'
    #hoverinfo=['all'],
    #textinfo=['all']
    textinfo="label+value",
))

# Update layout
fig.update_layout(
    uniformtext=dict(minsize=10, mode='hide'),
    margin = dict(t=50, l=25, r=25, b=25)
)

# The string below is an inert literal (it is never executed); it keeps
# a draft of an alternative trace for reference.
"""
# Add traces.
fig.add_trace(go.Treemap(
    #labels=df_trees.id,
    #parents=df_trees.parent,
    #values=df_trees.value,
    #branchvalues='total',
    labels=df_trees.ids,
    parents=df_trees.parents,
    values=df_trees.values,
    #marker=dict(
    #    colors=df_all_trees['color'],
    #    colorscale='RdBu',
    #    cmid=average_score),
    #hovertemplate='<b>%{label} </b> <br> Sales: %{value}<br> Success rate: %{color:.2f}',
    #name=''
), 1, 1)

# Update layout
#fig.update_layout(margin = dict(t=50, l=25, r=25, b=25))
"""

# Show
show(fig)
Total running time of the script: ( 0 minutes 0.133 seconds)