Note
Click here to download the full example code
05.c sns.clustermap with multiple categories
Plot a matrix dataset as a hierarchically-clustered heatmap.
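Note
Hierarchical clustering is turned off in this example (row_cluster=False and col_cluster=False), so rows and columns keep their original order. Color bars along both axes mark the network and node category of each row and column.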
10 # Libraries
11 import pandas as pd
12 import seaborn as sns
13 import matplotlib.pyplot as plt
14
15 from matplotlib.pyplot import gcf
Let’s load the dataset
# Load the "brain_networks" example dataset.
# index_col=0 turns the first column into the row index, and the first three
# header rows become a three-level column MultiIndex (levels named below,
# e.g. "network" and "node").
networks = sns.load_dataset(
    "brain_networks",
    header=[0, 1, 2],
    index_col=0,
)
Let’s create the network colors
# Create network colors: one color per distinct "network" label, then one
# color per column of ``networks``.
#
# The labels are converted to strings once, up front. The lookup table, the
# per-column mapping and any later ``network_lut[label]`` lookup made with a
# label from ``network_labels`` therefore all use the same key type; a
# str-keyed table looked up with non-string labels would yield NaN/KeyError.
network_labels = networks.columns.get_level_values("network").astype(str)
network_names = network_labels.unique()
network_pal = sns.cubehelix_palette(network_names.size,
                                    light=.9, dark=.1, reverse=True,
                                    start=1, rot=-2)
network_lut = dict(zip(network_names, network_pal))

# Name the Series explicitly so it becomes a distinct "network" column when
# combined with other per-column color Series.
network_colors = pd.Series(network_labels, index=networks.columns,
                           name="network").map(network_lut)
Let’s create the node colors
# Create node colors: one color per distinct "node" label, then one color per
# column of ``networks``.
#
# The labels are converted to strings once, up front. The lookup-table keys
# and every value looked up in it (here and in any later
# ``node_lut[label]`` call with a label from ``node_labels``) then share one
# key type.
node_labels = networks.columns.get_level_values("node").astype(str)
node_names = node_labels.unique()
node_pal = sns.cubehelix_palette(node_names.size)
node_lut = dict(zip(node_names, node_pal))

# Name the Series explicitly so it becomes a distinct "node" column when
# combined with other per-column color Series.
node_colors = pd.Series(node_labels, index=networks.columns,
                        name="node").map(node_lut)
Let’s combine the network and node colors into a single table.
# Combine the two per-column color Series into one DataFrame. Both are indexed
# by ``networks.columns``, so the join lines them up column-for-column, giving
# one color column per Series.
network_node_colors = network_colors.to_frame().join(node_colors.to_frame())
Let’s display the clustermap
# Display the correlation matrix of ``networks`` as a heatmap, with the
# network/node color table drawn as color bars along both axes.
g = sns.clustermap(networks.corr(),
                   row_cluster=False, col_cluster=False,  # turn off clusters
                   row_colors=network_node_colors,  # add colored labels
                   col_colors=network_node_colors,  # add colored labels
                   linewidths=0, xticklabels=False, yticklabels=False,
                   center=0, cmap="vlag", figsize=(7, 7))

# Legends are built from invisible zero-size bars drawn on the dendrogram
# axes; each bar only contributes a color/label pair as a legend handle.
# They are anchored in figure coordinates of the clustermap's own figure
# (g.fig) rather than whatever gcf() happens to return.

# Add legend for networks
for label in network_labels.unique():
    g.ax_col_dendrogram.bar(0, 0, color=network_lut[label], label=label,
                            linewidth=0)

l1 = g.ax_col_dendrogram.legend(title='Network', loc="center", ncol=5,
                                bbox_to_anchor=(0.53, 0.9),
                                bbox_transform=g.fig.transFigure)

# Add legend for nodes
for label in node_labels.unique():
    g.ax_row_dendrogram.bar(0, 0, color=node_lut[label], label=label,
                            linewidth=0)

l2 = g.ax_row_dendrogram.legend(title='Node', loc="center", ncol=1,
                                bbox_to_anchor=(0.86, 0.9),
                                bbox_transform=g.fig.transFigure)

plt.show()
Total running time of the script: (0 minutes 0.405 seconds)