Note
Click here to download the full example code
02. Plot Distributions
This example shows how to plot the distribution of the predicted probabilities for each component of the confusion matrix (TP, TN, FP, FN). Note that the data are generated artificially: the scores are drawn from a normal distribution, so some of them fall outside the interval [0, 1]. If necessary, the axes can be limited to the range [0, 1] to keep only valid probability values.
12 # Libraries
13 import plotly.graph_objects as go
14 import plotly.express as px
15 import pandas as pd
16 import numpy as np
17
18 from plotly.io import show
19
20 # Specific
21 from plotly.graph_objects import Layout
22
# Detect whether the script defines ``__file__``. The flag is used further
# down to decide whether the data frame is printed explicitly.
try:
    __file__
except NameError:
    # Only a missing name is expected here; anything else should propagate
    # instead of being silently swallowed by a bare ``except``.
    TERMINAL = False
else:
    TERMINAL = True
28
# -----------------------------------------
# Helper method
# -----------------------------------------
# This method is implemented in pySML.
def _tp_fp_tn_fn_distributions(y, y_pred, y_prob):
    """Split the predicted probabilities by confusion-matrix outcome.

    Every sample is assigned to one of the four confusion-matrix
    elements by comparing its real category with its predicted
    category. Samples whose categories are neither 0 nor 1 do not
    match any of the four masks and are left out of all groups.

    Parameters
    ----------
    y : array-like
        The real categories (0 or 1).

    y_pred : array-like
        The predicted categories (0 or 1).

    y_prob : array-like
        The predicted probabilities, one value per sample.

    Returns
    -------
    tp_probs, tn_probs, fp_probs, fn_probs
        The entries of ``y_prob`` for the true positives, true
        negatives, false positives and false negatives. Note that
        the order is (tp, tn, fp, fn), which differs from the order
        suggested by the function name. Lists and tuples are
        converted to numpy arrays; any other input is indexed as
        given.
    """
    def _as_indexable(values):
        # Plain Python sequences support neither element-wise comparison
        # (``[0, 1] == 1`` is just ``False``) nor boolean-mask indexing,
        # so they are converted. Anything else is passed through
        # untouched to keep the behaviour for arrays/Series unchanged.
        if isinstance(values, (list, tuple)):
            return np.asarray(values)
        return values

    y = _as_indexable(y)
    y_pred = _as_indexable(y_pred)
    y_prob = _as_indexable(y_prob)

    # Boolean masks, one per confusion-matrix element.
    tp_idx = (y_pred == 1) & (y == 1)
    tn_idx = (y_pred == 0) & (y == 0)
    fp_idx = (y_pred == 1) & (y == 0)
    fn_idx = (y_pred == 0) & (y == 1)

    # Select the probabilities that belong to each element.
    tp_probs = y_prob[tp_idx]
    tn_probs = y_prob[tn_idx]
    fp_probs = y_prob[fp_idx]
    fn_probs = y_prob[fn_idx]

    # Return
    return tp_probs, tn_probs, fp_probs, fn_probs
64
# -----------------------------------------
# Config
# -----------------------------------------
# Colors
# Color scale whose entries are used as fill colors for the violins.
# Alternatives that can be swapped in here:
#   px.colors.qualitative.Plotly
#   px.colors.sequential.Plasma_r
colors = px.colors.sequential.Viridis_r
72
# -----------------------------------------
# Data
# -----------------------------------------
# Create artificial data: random binary labels for the real and the
# predicted categories, and scores drawn from a normal distribution
# (loc=0, scale=1), so the scores are not restricted to [0, 1].
data = pd.DataFrame({
    'y_true': np.random.randint(2, size=100),
    'y_pred': np.random.randint(2, size=100),
    'y_prob': np.random.normal(loc=0, scale=1, size=100),
})

# Split the scores by confusion-matrix element.
tp_probs, tn_probs, fp_probs, fn_probs = _tp_fp_tn_fn_distributions(
    y=data['y_true'],
    y_pred=data['y_pred'],
    y_prob=data['y_prob'],
)

# Visualize
if TERMINAL:
    print("\nData:", data, sep="\n")
# Bare expression, presumably so that notebook-style runners display
# the frame as the cell output.
data
93
94
# -------------------------------------
# Visualize
# -------------------------------------
# Import subplots
from plotly.subplots import make_subplots

# Create a 2x2 grid with one panel per confusion-matrix element.
fig = make_subplots(rows=2, cols=2)

# Panel specification: (trace name, values, index into ``colors``, row, col).
# The list order is the order in which the traces are added.
panels = [
    ('tn', tn_probs, 2, 1, 1),
    ('fp', fp_probs, 4, 1, 2),
    ('fn', fn_probs, 6, 2, 1),
    ('tp', tp_probs, 0, 2, 2),
]

# Add one horizontal violin per panel; all share the same styling and
# differ only in data, name, fill color and grid position.
for label, values, color_idx, row, col in panels:
    violin = go.Violin(
        x=values,
        name=label,
        line_width=1,
        line_color='black',
        fillcolor=colors[color_idx],
        opacity=0.5,
        meanline_visible=True,
        box_visible=True,
    )
    fig.add_trace(violin, row=row, col=col)

# Update layout: fixed size, no margins and transparent backgrounds.
transparent = 'rgba(0,0,0,0)'
fig.update_layout(
    width=700,
    height=350,
    margin=dict(l=0, r=0, b=0, t=0, pad=0),
    paper_bgcolor=transparent,
    plot_bgcolor=transparent,
)

# Update axes
# To restrict a panel to a given interval, set the range of its x axis,
# e.g. fig.update_xaxes(visible=True, range=[0.0, 0.5], row=1, col=1).
fig.update_yaxes(visible=True)

# Show
show(fig)
Total running time of the script: ( 0 minutes 0.316 seconds)