02. Confusion matrix dist. (violin)

This example displays the probability distributions for each component of a confusion matrix: True Positives (TP), True Negatives (TN), False Positives (FP), and False Negatives (FN). The script uses synthetically generated data for illustration. It is possible and recommended to constrain the plot’s axis to the [0, 1] range to focus strictly on valid probability values.

 13 # Libraries
 14 import plotly.graph_objects as go
 15 import plotly.express as px
 16 import pandas as pd
 17 import numpy as np
 18
 19 from plotly.io import show
 20
 21 # Specific
 22 from plotly.graph_objects import Layout
 23
 24 try:
 25     __file__
 26     TERMINAL = True
 27 except:
 28     TERMINAL = False
 29
 30 # -----------------------------------------
 31 # Helper method
 32 # -----------------------------------------
 33 # This method is implemented in pySML.
 34 def _tp_fp_tn_fn_distributions(y, y_pred, y_prob):
 35     """This function returns probabilities for each of the confusion
 36     matrix elements (tp, tn, fp, fn).
 37
 38     Parameters
 39     ----------
 40     y : array-like
 41       The real categories
 42
 43     y_pred : array-like
 44       The predicted categories
 45
 46     y_prob: array-like
 47       The predict probabilities
 48
 49     Returns
 50     -------
 51     tp_probs, tn_probs, fp_probs, fn_probs
 52     """
 53     # Tags.
 54     tp_idx = (y_pred == 1) & (y == 1)
 55     tn_idx = (y_pred == 0) & (y == 0)
 56     fp_idx = (y_pred == 1) & (y == 0)
 57     fn_idx = (y_pred == 0) & (y == 1)
 58     # Show information.
 59     tp_probs = y_prob[tp_idx]
 60     tn_probs = y_prob[tn_idx]
 61     fp_probs = y_prob[fp_idx]
 62     fn_probs = y_prob[fn_idx]
 63     # Return
 64     return tp_probs, tn_probs, fp_probs, fn_probs
 65
 66 # -----------------------------------------
 67 # Config
 68 # -----------------------------------------
 69 # Colors
 70 colors = px.colors.qualitative.Plotly
 71 colors = px.colors.sequential.Plasma_r
 72 colors = px.colors.sequential.Viridis_r
 73
 74 # -----------------------------------------
 75 # Data
 76 # -----------------------------------------
 77 # Create data
 78 data = pd.DataFrame()
 79 data['y_true'] = np.random.randint(2, size=100)
 80 data['y_pred'] = np.random.randint(2, size=100)
 81 data['y_prob'] = np.random.normal(loc=0, scale=1, size=100)
 82
 83 # Get distributions
 84 tp_probs, tn_probs, fp_probs, fn_probs = \
 85     _tp_fp_tn_fn_distributions(data.y_true,
 86                                data.y_pred,
 87                                data.y_prob)
 88
 89 # Visualize
 90 if TERMINAL:
 91     print("\nData:")
 92     print(data)
 93 data
 94
 95
 96 # -------------------------------------
 97 # Visualize
 98 # -------------------------------------
 99 # Import subplots
100 from plotly.subplots import make_subplots
101
102 # Create figure
103 fig = make_subplots(rows=2, cols=2)
104 # subplot_titles=('TP', 'TN', 'FP', 'FN'))
105
106 #  Add traces
107 fig.add_trace(go.Violin(x=tn_probs, line_width=1,
108     name='tn', line_color='black', fillcolor=colors[2],
109     opacity=0.5, meanline_visible=True, box_visible=True), row=1, col=1)
110 fig.add_trace(go.Violin(x=fp_probs, line_width=1,
111     name='fp', line_color='black', fillcolor=colors[4],
112     opacity=0.5, meanline_visible=True, box_visible=True), row=1, col=2)
113 fig.add_trace(go.Violin(x=fn_probs, line_width=1,
114     name='fn', line_color='black', fillcolor=colors[6],
115     opacity=0.5, meanline_visible=True, box_visible=True), row=2, col=1)
116 fig.add_trace(go.Violin(x=tp_probs, line_width=1,
117     name='tp', line_color='black', fillcolor=colors[0],
118     opacity=0.5, meanline_visible=True, box_visible=True), row=2, col=2)
119
# Fully transparent color shared by the paper and the plot background.
transparent = 'rgba(0,0,0,0)'

# Update layout: fixed figure size, no margins around the grid and
# transparent backgrounds.
fig.update_layout(
    width=700,
    height=350,
    margin=dict(l=0, r=0, b=0, t=0, pad=0),
    paper_bgcolor=transparent,
    plot_bgcolor=transparent,
)

# Update axes
# The x-axis of each panel can optionally be constrained, e.g. with
# fig.update_xaxes(visible=True, range=[0.0, 1.0], row=1, col=1),
# to display only the range of valid probabilities.
fig.update_yaxes(visible=True)

# Show
show(fig)

Total running time of the script: ( 0 minutes 1.894 seconds)

Gallery generated by Sphinx-Gallery