Note
Click here to download the full example code
01. Plot ROC
This example shows how to plot ROC curves for several data splits on a single Plotly figure.
# sphinx_gallery_thumbnail_path = '_static/images/icon-github.svg'
15 # Libraries
16 import plotly.graph_objects as go
17 import plotly.express as px
18 import numpy as np
19
20 from plotly.io import show
21
22 # Layout
23 from plotly.graph_objects import Layout
24
# -----------------------------------------
# Config
# -----------------------------------------
# Colour sequence for the split traces, indexed by split position.
# Other palettes that can be swapped in here include
# px.colors.qualitative.Plotly and px.colors.sequential.Plasma_r.
colors = px.colors.sequential.Viridis_r
32
# -----------------------------------------
# Data
# -----------------------------------------
# Synthetic ROC-like curves: a straight FPR == TPR line, shifted
# horizontally by a per-split offset so the curves are distinguishable.
# Shifted FPR values may fall outside [0, 1].
fpr = np.arange(10) / 10
tpr = np.arange(10) / 10

# Horizontal FPR offset for each split, in order: splits 1-5 shift left
# and splits 6-10 shift right, in steps of 0.05.
offsets = [-0.10, -0.15, -0.20, -0.25, -0.30,
           0.10, 0.15, 0.20, 0.25, 0.30]

# Map 'split1' ... 'split10' to an (n, 2) array with columns (fpr, tpr).
data = {
    f'split{i}': np.column_stack((fpr + offset, tpr))
    for i, offset in enumerate(offsets, start=1)
}
53
54
# -------------------------------------
# Visualize
# -------------------------------------
# Create figure
fig = go.Figure()

# Dashed diagonal from (0, 0) to (1, 1): the chance-level reference
# line of a ROC plot.
fig.add_shape(type='line', x0=0, x1=1, y0=0, y1=1,
              line=dict(dash='dash', color='gray', width=1))

# Plot each split as its own trace; array columns are (fpr, tpr).
# The colour index wraps around so that having more splits than palette
# entries does not raise IndexError.
for i, (name, array) in enumerate(data.items()):
    fig.add_trace(go.Scatter(x=array[:, 0],
                             y=array[:, 1],
                             name=name,
                             mode='lines+markers',
                             line=dict(color=colors[i % len(colors)],
                                       width=0.5)))

# Update layout: square aspect ratio (y scaled to x), small figure,
# legend inside the plot, no margins, transparent background.
fig.update_layout(
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate',
    yaxis=dict(scaleanchor="x", scaleratio=1),
    xaxis=dict(constrain='domain'),
    width=350, height=350,
    legend=dict(
        # Anchor the legend's bottom-right corner at the plot's
        # bottom-right corner.
        x=1.0, y=0.0,
        orientation="v",
        font={'size': 12},
        yanchor="bottom",
        xanchor="right",
    ),
    margin={
        'l': 0,
        'r': 0,
        'b': 0,
        't': 0,
        'pad': 4
    },
    paper_bgcolor='rgba(0,0,0,0)',  # transparent
    plot_bgcolor='rgba(0,0,0,0)'    # transparent
)

# Update xaxes: light grid, fixed [0, 1] range
fig.update_xaxes(showgrid=True,
                 gridwidth=1,
                 nticks=10,
                 range=[0, 1],
                 gridcolor='lightgray')

# Update yaxes: light grid, fixed [0, 1] range
fig.update_yaxes(showgrid=True,
                 gridwidth=1,
                 range=[0, 1],
                 nticks=10,
                 gridcolor='lightgray')

# Show
show(fig)
Total running time of the script: (0 minutes 2.003 seconds)