Note
Click here to download the full example code
02. Confusion matrix dist. (violin)
This example displays, for each cell of a binary confusion matrix — True Positives (TP), True Negatives (TN), False Positives (FP) and False Negatives (FN) — the distribution of the predicted probabilities of the samples that fall into that cell. The script uses synthetically generated data for illustration. Because the plotted values are probabilities, it is recommended to constrain each plot's x-axis to the [0, 1] range so that only valid probability values are shown.
13 # Libraries
14 import plotly.graph_objects as go
15 import plotly.express as px
16 import pandas as pd
17 import numpy as np
18
19 from plotly.io import show
20
21 # Specific
22 from plotly.graph_objects import Layout
23
# TERMINAL is True when the name ``__file__`` is defined (the file is being
# run as a script) and False otherwise (presumably an interactive session).
# Only NameError is caught: a bare ``except:`` would also swallow
# KeyboardInterrupt/SystemExit and hide unrelated errors.
try:
    __file__
    TERMINAL = True
except NameError:
    TERMINAL = False
29
30 # -----------------------------------------
31 # Helper method
32 # -----------------------------------------
33 # This method is implemented in pySML.
34 def _tp_fp_tn_fn_distributions(y, y_pred, y_prob):
35 """This function returns probabilities for each of the confusion
36 matrix elements (tp, tn, fp, fn).
37
38 Parameters
39 ----------
40 y : array-like
41 The real categories
42
43 y_pred : array-like
44 The predicted categories
45
46 y_prob: array-like
47 The predict probabilities
48
49 Returns
50 -------
51 tp_probs, tn_probs, fp_probs, fn_probs
52 """
53 # Tags.
54 tp_idx = (y_pred == 1) & (y == 1)
55 tn_idx = (y_pred == 0) & (y == 0)
56 fp_idx = (y_pred == 1) & (y == 0)
57 fn_idx = (y_pred == 0) & (y == 1)
58 # Show information.
59 tp_probs = y_prob[tp_idx]
60 tn_probs = y_prob[tn_idx]
61 fp_probs = y_prob[fp_idx]
62 fn_probs = y_prob[fn_idx]
63 # Return
64 return tp_probs, tn_probs, fp_probs, fn_probs
65
# -----------------------------------------
# Config
# -----------------------------------------
# Colors. Other palettes that can be swapped in here include
# px.colors.qualitative.Plotly and px.colors.sequential.Plasma_r.
colors = px.colors.sequential.Viridis_r
73
# -----------------------------------------
# Data
# -----------------------------------------
# Create synthetic classifier output. A seeded generator keeps the example
# reproducible. Probabilities are drawn from Beta distributions, so they
# always lie in [0, 1]:
#   - Beta(4, 2), mean ~0.67, for true positives-class samples (y_true = 1);
#   - Beta(2, 4), mean ~0.33, for negative-class samples (y_true = 0).
# The predicted label is obtained by thresholding the probability at 0.5,
# as a classifier would.
rng = np.random.default_rng(seed=42)
n_samples = 100
y_true = rng.integers(2, size=n_samples)
probs = rng.beta(a=2 + 2 * y_true, b=4 - 2 * y_true)

data = pd.DataFrame({
    'y_true': y_true,
    'y_pred': (probs >= 0.5).astype(int),
    'y_prob': probs,
})

# Get distributions (note the tp, tn, fp, fn return order).
tp_probs, tn_probs, fp_probs, fn_probs = _tp_fp_tn_fn_distributions(
    data.y_true, data.y_pred, data.y_prob)

# Visualize
if TERMINAL:
    print("\nData:")
    print(data)
# Bare expression so that notebook-style renderers display the table.
data
95
# -------------------------------------
# Visualize
# -------------------------------------
from plotly.subplots import make_subplots

# Panels are laid out like a confusion matrix: rows are the true class
# (0, 1) and columns are the predicted class (0, 1).
# Each entry: (trace name, probabilities, colour index, row, col).
cells = [
    ('tn', tn_probs, 2, 1, 1),
    ('fp', fp_probs, 4, 1, 2),
    ('fn', fn_probs, 6, 2, 1),
    ('tp', tp_probs, 0, 2, 2),
]

# Create figure
fig = make_subplots(rows=2, cols=2)

# Add one horizontal violin (x only) per confusion-matrix cell.
for name, probs, color_idx, row, col in cells:
    fig.add_trace(
        go.Violin(x=probs, name=name, line_width=1, line_color='black',
                  fillcolor=colors[color_idx], opacity=0.5,
                  meanline_visible=True, box_visible=True),
        row=row, col=col)

# Update layout
fig.update_layout(
    width=700, height=350,
    margin={'l': 0, 'r': 0, 'b': 0, 't': 0, 'pad': 0},
    paper_bgcolor='rgba(0,0,0,0)',  # transparent
    plot_bgcolor='rgba(0,0,0,0)',   # transparent
)

# The plotted values are probabilities, so restrict every x-axis to [0, 1].
fig.update_xaxes(range=[0.0, 1.0])

# Show
show(fig)
Total running time of the script: ( 0 minutes 1.894 seconds)