01. ROC curve (scatter)

This example demonstrates how to plot ROC (receiver operating characteristic) curves for several data splits on a single figure.

  8 # Libraries
  9 import plotly.graph_objects as go
 10 import plotly.express as px
 11 import numpy as np
 12
 13 from plotly.io import show
 14
 15 # Layout
 16 from plotly.graph_objects import Layout
 17
 18 # -----------------------------------------
 19 # Config
 20 # -----------------------------------------
 21 # Colors
 22 #colors = px.colors.qualitative.Plotly
 23 #colors = px.colors.sequential.Plasma_r
 24 colors = px.colors.sequential.Viridis_r
 25
 26 # -----------------------------------------
 27 # Data
 28 # -----------------------------------------
 29 # Create some data
 30 fpr = np.arange(10)/10
 31 tpr = np.arange(10)/10
 32
 33 # Data
 34 data = {
 35     'split1': np.vstack((fpr + -0.10, tpr)).T,
 36     'split2': np.vstack((fpr + -0.15, tpr)).T,
 37     'split3': np.vstack((fpr + -0.20, tpr)).T,
 38     'split4': np.vstack((fpr + -0.25, tpr)).T,
 39     'split5': np.vstack((fpr + -0.30, tpr)).T,
 40     'split6': np.vstack((fpr + 0.15, tpr)).T,
 41     'split7': np.vstack((fpr + 0.15, tpr)).T,
 42     'split8': np.vstack((fpr + 0.20, tpr)).T,
 43     'split9': np.vstack((fpr + 0.25, tpr)).T,
 44     'split10': np.vstack((fpr + 0.30, tpr)).T
 45 }
 46
 47
 48 # -------------------------------------
 49 # Visualize
 50 # -------------------------------------
 51 # Create figure
 52 fig = go.Figure()
 53
 54 # Add diagonal line
 55 fig.add_shape(type='line', x0=0, x1=1, y0=0, y1=1,
 56     line=dict(dash='dash', color='gray', width=1),
 57 )
 58
 59 # Plot each split
 60 for i, (name, array) in enumerate(data.items()):
 61     # Name of split
 62     name = f"{name}" # (AUC={10:.2f})"
 63     # Add trace
 64     fig.add_trace(go.Scatter(x=array[:, 0],
 65                              y=array[:, 1],
 66                              name=name,
 67                              mode='lines+markers',
 68                              line=dict(width=0.5))
 69     )
 70
 71 # Update layout
 72 fig.update_layout(
 73     xaxis_title='False Positive Rate',
 74     yaxis_title='True Positive Rate',
 75     yaxis=dict(scaleanchor="x", scaleratio=1),
 76     xaxis=dict(constrain='domain'),
 77     width=350, height=350,
 78     legend=dict(
 79         x=1.0, y=0.0,  # x=1, y=1.02
 80         orientation="v",
 81         font={'size': 12},
 82         yanchor="bottom",
 83         xanchor="right",
 84     ),
 85     margin={
 86         'l': 0,
 87         'r': 0,
 88         'b': 0,
 89         't': 0,
 90         'pad': 4
 91     },
 92     paper_bgcolor='rgba(0,0,0,0)',  # transparent
 93     plot_bgcolor='rgba(0,0,0,0)',  # transparent
 94     colorway=colors
 95 )
 96
 97 # Update xaxes
 98 fig.update_xaxes(showgrid=True,
 99                  gridwidth=1,
100                  nticks=10,
101                  range=[0, 1],
102                  gridcolor='lightgray')
103
104 # Update yaxes
105 fig.update_yaxes(showgrid=True,
106                  gridwidth=1,
107                  range=[0, 1],
108                  nticks=10,
109                  gridcolor='lightgray')
110
111 # Show
112 show(fig)

Total running time of the script: ( 0 minutes 3.375 seconds)

Gallery generated by Sphinx-Gallery