01. Plot ROC

This example shows how to plot the ROC curves for various splits.

# sphinx_gallery_thumbnail_path = '_static/images/icon-github.svg'

 15 # Libraries
 16 import plotly.graph_objects as go
 17 import plotly.express as px
 18 import numpy as np
 19
 20 from plotly.io import show
 21
 22 # Layout
 23 from plotly.graph_objects import Layout
 24
 25 # -----------------------------------------
 26 # Config
 27 # -----------------------------------------
 28 # Colors
 29 colors = px.colors.qualitative.Plotly
 30 colors = px.colors.sequential.Plasma_r
 31 colors = px.colors.sequential.Viridis_r
 32
 33 # -----------------------------------------
 34 # Data
 35 # -----------------------------------------
 36 # Create some data
 37 fpr = np.arange(10)/10
 38 tpr = np.arange(10)/10
 39
 40 # Data
 41 data = {
 42     'split1': np.vstack((fpr + -0.10, tpr)).T,
 43     'split2': np.vstack((fpr + -0.15, tpr)).T,
 44     'split3': np.vstack((fpr + -0.20, tpr)).T,
 45     'split4': np.vstack((fpr + -0.25, tpr)).T,
 46     'split5': np.vstack((fpr + -0.30, tpr)).T,
 47     'split6': np.vstack((fpr + 0.15, tpr)).T,
 48     'split7': np.vstack((fpr + 0.15, tpr)).T,
 49     'split8': np.vstack((fpr + 0.20, tpr)).T,
 50     'split9': np.vstack((fpr + 0.25, tpr)).T,
 51     'split10': np.vstack((fpr + 0.30, tpr)).T
 52 }
 53
 54
 55 # -------------------------------------
 56 # Visualize
 57 # -------------------------------------
 58 # Create figure
 59 fig = go.Figure()
 60
 61 # Add diagonal line
 62 fig.add_shape(type='line', x0=0, x1=1, y0=0, y1=1,
 63     line=dict(dash='dash', color='gray', width=1),
 64 )
 65
 66 # Plot each split
 67 for i, (name, array) in enumerate(data.items()):
 68     # Name of split
 69     name = f"{name}" # (AUC={10:.2f})"
 70     # Add trace
 71     fig.add_trace(go.Scatter(x=array[:, 0],
 72                              y=array[:, 1],
 73                              name=name,
 74                              mode='lines+markers',
 75                              line=dict(color=colors[i],
 76                                        width=0.5)))
 77
 78 # Update layout
 79 fig.update_layout(
 80     xaxis_title='False Positive Rate',
 81     yaxis_title='True Positive Rate',
 82     yaxis=dict(scaleanchor="x", scaleratio=1),
 83     xaxis=dict(constrain='domain'),
 84     width=350, height=350,
 85     legend=dict(
 86         x=1.0, y=0.0,  # x=1, y=1.02
 87         orientation="v",
 88         font={'size': 12},
 89         yanchor="bottom",
 90         xanchor="right",
 91     ),
 92     margin={
 93         'l': 0,
 94         'r': 0,
 95         'b': 0,
 96         't': 0,
 97         'pad': 4
 98     },
 99     paper_bgcolor='rgba(0,0,0,0)',  # transparent
100     plot_bgcolor='rgba(0,0,0,0)'  # transparent
101 )
102
103 # Update xaxes
104 fig.update_xaxes(showgrid=True,
105                  gridwidth=1,
106                  nticks=10,
107                  range=[0, 1],
108                  gridcolor='lightgray')
109
110 # Update yaxes
111 fig.update_yaxes(showgrid=True,
112                  gridwidth=1,
113                  range=[0, 1],
114                  nticks=10,
115                  gridcolor='lightgray')
116
117 # Show
118 show(fig)

Total running time of the script: (0 minutes 2.003 seconds)

Gallery generated by Sphinx-Gallery