02. Plot Distributions

This example shows how to plot the probability distribution of each component of the confusion matrix (TP, TN, FP and FN). Note that the data is generated artificially: the "probabilities" are drawn from a standard normal distribution, so they are not restricted to [0, 1]. If needed, the x-axis can be limited to the range [0, 1] to keep only valid probability values.

 12 # Libraries
 13 import plotly.graph_objects as go
 14 import plotly.express as px
 15 import pandas as pd
 16 import numpy as np
 17
 18 from plotly.io import show
 19
 20 # Specific
 21 from plotly.graph_objects import Layout
 22
 23 try:
 24     __file__
 25     TERMINAL = True
 26 except:
 27     TERMINAL = False
 28
 29 # -----------------------------------------
 30 # Helper method
 31 # -----------------------------------------
 32 # This method is implemented in pySML.
 33 def _tp_fp_tn_fn_distributions(y, y_pred, y_prob):
 34     """This function returns probabilities for each of the confusion
 35     matrix elements (tp, tn, fp, fn).
 36
 37     Parameters
 38     ----------
 39     y : array-like
 40       The real categories
 41
 42     y_pred : array-like
 43       The predicted categories
 44
 45     y_prob: array-like
 46       The predict probabilities
 47
 48     Returns
 49     -------
 50     tp_probs, tn_probs, fp_probs, fn_probs
 51     """
 52     # Tags.
 53     tp_idx = (y_pred == 1) & (y == 1)
 54     tn_idx = (y_pred == 0) & (y == 0)
 55     fp_idx = (y_pred == 1) & (y == 0)
 56     fn_idx = (y_pred == 0) & (y == 1)
 57     # Show information.
 58     tp_probs = y_prob[tp_idx]
 59     tn_probs = y_prob[tn_idx]
 60     fp_probs = y_prob[fp_idx]
 61     fn_probs = y_prob[fn_idx]
 62     # Return
 63     return tp_probs, tn_probs, fp_probs, fn_probs
 64
 65 # -----------------------------------------
 66 # Config
 67 # -----------------------------------------
 68 # Colors
 69 colors = px.colors.qualitative.Plotly
 70 colors = px.colors.sequential.Plasma_r
 71 colors = px.colors.sequential.Viridis_r
 72
 73 # -----------------------------------------
 74 # Data
 75 # -----------------------------------------
 76 # Create data
 77 data = pd.DataFrame()
 78 data['y_true'] = np.random.randint(2, size=100)
 79 data['y_pred'] = np.random.randint(2, size=100)
 80 data['y_prob'] = np.random.normal(loc=0, scale=1, size=100)
 81
 82 # Get distributions
 83 tp_probs, tn_probs, fp_probs, fn_probs = \
 84     _tp_fp_tn_fn_distributions(data.y_true,
 85                                data.y_pred,
 86                                data.y_prob)
 87
 88 # Visualize
 89 if TERMINAL:
 90     print("\nData:")
 91     print(data)
 92 data
 93
 94
 95 # -------------------------------------
 96 # Visualize
 97 # -------------------------------------
 98 # Import subplots
 99 from plotly.subplots import make_subplots
100
101 # Create figure
102 fig = make_subplots(rows=2, cols=2)
103 # subplot_titles=('TP', 'TN', 'FP', 'FN'))
104
105 #  Add traces
106 fig.add_trace(go.Violin(x=tn_probs, line_width=1,
107     name='tn', line_color='black', fillcolor=colors[2],
108     opacity=0.5, meanline_visible=True, box_visible=True), row=1, col=1)
109 fig.add_trace(go.Violin(x=fp_probs, line_width=1,
110     name='fp', line_color='black', fillcolor=colors[4],
111     opacity=0.5, meanline_visible=True, box_visible=True), row=1, col=2)
112 fig.add_trace(go.Violin(x=fn_probs, line_width=1,
113     name='fn', line_color='black', fillcolor=colors[6],
114     opacity=0.5, meanline_visible=True, box_visible=True), row=2, col=1)
115 fig.add_trace(go.Violin(x=tp_probs, line_width=1,
116     name='tp', line_color='black', fillcolor=colors[0],
117     opacity=0.5, meanline_visible=True, box_visible=True), row=2, col=2)
118
119 # Update layout
120 fig.update_layout(
121     width=700, height=350,
122     #xaxis_title='False Positive Rate',
123     #yaxis_title='True Positive Rate',
124     #yaxis=dict(scaleanchor="x", scaleratio=1),
125     #xaxis=dict(constrain='domain'),
126     #legend=dict(
127     #    x=1.0, y=0.0,  # x=1, y=1.02
128     #    orientation="v",
129     #    font={'size': 12},
130     #    yanchor="bottom",
131     #    xanchor="right",
132     #),
133     margin={
134         'l': 0,
135         'r': 0,
136         'b': 0,
137         't': 0,
138         'pad': 0
139     },
140     paper_bgcolor='rgba(0,0,0,0)',  # transparent
141     plot_bgcolor='rgba(0,0,0,0)'  # transparent
142 )
143
144 # Update axes
145 #fig.update_xaxes(visible=True, range=[0.0, 0.5], row=1, col=1)
146 #fig.update_xaxes(visible=True, range=[0.5, 1.0], row=1, col=2)
147 #fig.update_xaxes(visible=True, range=[0.0, 0.5], row=2, col=1)
148 #fig.update_xaxes(visible=True, range=[0.5, 1.0], row=2, col=2)
149 fig.update_yaxes(visible=True)
150
151 # Show
152 show(fig)

Total running time of the script: ( 0 minutes 0.316 seconds)

Gallery generated by Sphinx-Gallery