05.c `sns.clustermap` with multiple categories

Plot a matrix dataset as a hierarchically-clustered heatmap.

Note

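Hierarchical clustering is deactivated in this example (`row_cluster=False`, `col_cluster=False`), so rows and columns are shown in the dataset’s original column order.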

# Libraries
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from matplotlib.pyplot import gcf

Let’s load the `brain_networks` dataset, whose columns have a three-level header.

# Load dataset: the first column is used as the row index, and the first
# three rows form the header, giving a three-level column MultiIndex
# (the "network" and "node" levels are used below).
networks = sns.load_dataset("brain_networks", index_col=0, header=[0, 1, 2])

Let’s create the network colors: one color per network, used to annotate the rows and columns of the heatmap.

# Create network colors: one cubehelix color per distinct network label.
network_labels = networks.columns.get_level_values("network")
network_names = network_labels.unique()
network_pal = sns.cubehelix_palette(
    network_names.size, light=.9, dark=.1, reverse=True, start=1, rot=-2
)
# Key the lookup table by the raw labels, i.e. exactly the values that
# ``Series.map`` below and the legend loop look up. Converting only the keys
# with ``str`` would make non-string labels map to NaN / raise KeyError.
network_lut = dict(zip(network_names, network_pal))

# One color per heatmap column, indexed like ``networks.columns`` so that
# clustermap can match the colors to the correlation matrix. The Series name
# becomes the label of this color strip.
network_colors = pd.Series(
    network_labels, index=networks.columns, name="network"
).map(network_lut)

Let’s create the node colors in the same way, with one color per node label.

# Create node colors: one color per distinct node label, taken from a
# cubehelix palette with seaborn's default settings.
node_labels = networks.columns.get_level_values("node")
node_names = node_labels.unique()
node_pal = sns.cubehelix_palette(node_names.size)
# Key the lookup table by the raw labels, i.e. exactly the values that
# ``Series.map`` below and the legend loop look up.
node_lut = dict(zip(node_names, node_pal))

# One color per heatmap column, indexed like ``networks.columns``; the
# Series name becomes the label of this color strip.
node_colors = pd.Series(
    node_labels, index=networks.columns, name="node"
).map(node_lut)

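Let’s combine both color series into one DataFrame, so that each side of the heatmap gets two color strips (network and node).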

# Combine: place the two color Series side by side, aligned on their shared
# index (``networks.columns``). Each column of the resulting DataFrame is
# drawn as one strip of color bars. ``concat`` is used instead of ``join``
# because ``join`` fails when the Series are unnamed or share a name, and
# multiplies rows if the index contains duplicates.
network_node_colors = pd.concat([network_colors, node_colors], axis=1)

Let’s display the clustermap and add a legend for each color strip.

# Display the correlation matrix with clustering switched off, annotating
# both rows and columns with the network/node color strips.
g = sns.clustermap(
    networks.corr(),
    row_cluster=False, col_cluster=False,  # turn off clustering
    row_colors=network_node_colors,  # colored labels along the rows
    col_colors=network_node_colors,  # colored labels along the columns
    linewidths=0, xticklabels=False, yticklabels=False,
    center=0, cmap="vlag", figsize=(7, 7),
)

# Legend anchors below are in figure coordinates. Take the transform from the
# grid's own figure rather than ``gcf()``, which is only correct while this
# figure happens to be the current one.
fig_transform = g.ax_heatmap.figure.transFigure

# Add legend for networks. No dendrograms are drawn (clustering is off), so
# zero-size bars are added to the column-dendrogram axes purely to create
# one legend handle per network color.
for label, color in network_lut.items():
    g.ax_col_dendrogram.bar(0, 0, color=color, label=label, linewidth=0)

l1 = g.ax_col_dendrogram.legend(
    title="Network", loc="center", ncol=5,
    bbox_to_anchor=(0.53, 0.9), bbox_transform=fig_transform,
)

# Add legend for nodes, built the same way on the row-dendrogram axes so it
# does not replace the network legend.
for label, color in node_lut.items():
    g.ax_row_dendrogram.bar(0, 0, color=color, label=label, linewidth=0)

l2 = g.ax_row_dendrogram.legend(
    title="Node", loc="center", ncol=1,
    bbox_to_anchor=(0.86, 0.9), bbox_transform=fig_transform,
)

plt.show()
*Figure: output of the script above (plot main05 c clustermap).*

Total running time of the script: (0 minutes 0.405 seconds)

Gallery generated by Sphinx-Gallery