Note
Click here to download the full example code
01. ROC curve (scatter)
This example demonstrates how to plot ROC curves for several different data splits on a single figure.
8 # Libraries
9 import plotly.graph_objects as go
10 import plotly.express as px
11 import numpy as np
12
13 from plotly.io import show
14
15 # Layout
16 from plotly.graph_objects import Layout
17
# -----------------------------------------
# Config
# -----------------------------------------
# Color sequence used for the traces (passed to the layout as ``colorway``).
# Other palettes that can be swapped in here include
# ``px.colors.qualitative.Plotly`` and ``px.colors.sequential.Plasma_r``.
colors = px.colors.sequential.Viridis_r
25
# -----------------------------------------
# Data
# -----------------------------------------
# Base synthetic curve: ten evenly spaced values 0.0, 0.1, ..., 0.9 used for
# both the false positive rate (FPR) and the true positive rate (TPR).
fpr = np.arange(10) / 10
tpr = np.arange(10) / 10

# Horizontal FPR offset for each split. The offsets mirror each other:
# -0.10 ... -0.30 for split1-split5 and +0.10 ... +0.30 for split6-split10.
# (split6 previously repeated split7's +0.15, producing two identical,
# overlapping curves.)
offsets = [-0.10, -0.15, -0.20, -0.25, -0.30, 0.10, 0.15, 0.20, 0.25, 0.30]

# Map 'split<k>' -> (10, 2) array whose columns are (FPR, TPR).
# NOTE: shifted FPR values are not clipped to [0, 1] here.
data = {
    f"split{k}": np.vstack((fpr + offset, tpr)).T
    for k, offset in enumerate(offsets, start=1)
}
46
47
# -------------------------------------
# Visualize
# -------------------------------------
# Create figure
fig = go.Figure()

# Dashed gray diagonal from (0, 0) to (1, 1): the chance-level reference
# line of an ROC plot.
fig.add_shape(type='line', x0=0, x1=1, y0=0, y1=1,
              line=dict(dash='dash', color='gray', width=1))

# One thin lines+markers trace per split. Column 0 of each array is plotted
# on x (FPR) and column 1 on y (TPR); the split key is the legend label.
for name, array in data.items():
    fig.add_trace(go.Scatter(x=array[:, 0],
                             y=array[:, 1],
                             name=name,
                             mode='lines+markers',
                             line=dict(width=0.5)))
70
# Update layout
# 350x350 px figure with a 1:1 aspect ratio between the axes, transparent
# backgrounds, zero outer margin and the legend anchored at the bottom-right
# corner of the plotting area.
legend_style = dict(
    x=1.0,
    y=0.0,
    orientation="v",
    font={'size': 12},
    xanchor="right",
    yanchor="bottom",
)
transparent = 'rgba(0,0,0,0)'

fig.update_layout(
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate',
    xaxis=dict(constrain='domain'),
    yaxis=dict(scaleanchor="x", scaleratio=1),
    width=350,
    height=350,
    legend=legend_style,
    margin=dict(l=0, r=0, b=0, t=0, pad=4),
    paper_bgcolor=transparent,
    plot_bgcolor=transparent,
    colorway=colors,
)

# Both axes share identical styling: light-gray grid, about ten ticks, and a
# fixed [0, 1] range.
axis_style = dict(showgrid=True,
                  gridwidth=1,
                  gridcolor='lightgray',
                  nticks=10,
                  range=[0, 1])
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

# Show
show(fig)
Total running time of the script: (0 minutes 3.375 seconds)