01. Test Skfold

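Demonstrate that StratifiedKFold is deterministic and always returns the same splits when used with its default settings (shuffle=False). The way to change the splits is by setting shuffle=True and changing the random state; with shuffle=False the random state has no effect.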

Out:

Example I:
0 == 1 : True
1 == 2 : True
2 == 3 : True
3 == 4 : True

Example II:
0 == 1 : True
1 == 2 : True
2 == 3 : True
3 == 4 : True

 9 # General
10 import numpy as np
11 import pandas as pd
12
13 # Specific
14 from sklearn.model_selection import StratifiedKFold
15
16
17 # ---------------------------------------------------
18 #
19 # ---------------------------------------------------
def repeated_splits(X, y, n_loops=2, n_splits=5, *,
                    shuffle=False, random_state=None):
    """Compute the stratified k-fold splits of (X, y) several times.

    A brand new ``StratifiedKFold`` splitter is created on every
    repetition, so the returned records can be compared with each
    other to check that the splitting is consistent.

    Parameters
    ----------
    X, y:
        Data and labels, forwarded unchanged to
        ``StratifiedKFold.split``.
    n_loops: int
        Number of times the splits are recomputed.
    n_splits: int
        Number of folds requested from the splitter.
    shuffle: bool
        Forwarded to ``StratifiedKFold``. The default (False)
        matches the splitter's own default.
    random_state: int, RandomState instance or None
        Forwarded to ``StratifiedKFold``. Only meaningful when
        ``shuffle`` is True.

    Returns
    -------
    list of pandas.DataFrame
        One DataFrame per repetition, with one column ``fold_<j>``
        per fold holding the train indices followed by the test
        indices of that fold.
    """
    # Record for comparison
    records = []

    for _ in range(n_loops):
        # A fresh splitter per repetition: the point is to compare
        # independently created splitters.
        skf = StratifiedKFold(n_splits=n_splits,
                              shuffle=shuffle,
                              random_state=random_state)

        # Collect all the folds first and build the frame in one go
        # rather than inserting the columns one by one.
        folds = {}
        for j, (train, test) in enumerate(skf.split(X, y)):
            folds['fold_{0}'.format(j)] = np.concatenate((train, test))

        records.append(pd.DataFrame(folds))

    return records
44
45
46 # ---------------------------------------------------
47 # Artificial example
48 # ---------------------------------------------------
# Number of samples, folds and repetitions
n = 2000
n_splits = 5
n_loops = 5

# Random dataset: integer features in [0, 10) and a binary
# target which is 1 whenever the uniform draw exceeds 0.1.
X = np.random.randint(10, size=(n, 7))
y = (np.random.rand(n) > 0.1).astype(int)

# Compute the splits n_loops times
records = repeated_splits(X, y, n_splits=n_splits, n_loops=n_loops)

# Check each record against the next one
print("\nExample I:")
for i, (current, following) in enumerate(zip(records, records[1:])):
    same = current.equals(following)
    print('{0} == {1} : {2}'.format(i, i + 1, same))
67
68
69
70 # ---------------------------------------------------
71 # Real example
72 # ---------------------------------------------------
73 # Libraries
74 from sklearn.datasets import load_iris
75
# Load data
bunch = load_iris(as_frame=True)

# Label conversion (class index -> class name)
lblmap = dict(enumerate(bunch.target_names))

# Dataframe. Work on a copy so that adding columns does not
# modify the frame owned by the bunch.
df = bunch.data.copy()
df['target'] = bunch.target
df['label'] = df.target.map(lblmap)

# Get X, and y. The feature matrix must not contain the target
# nor the label columns; otherwise X would include the class
# itself and become an object array because of the strings.
X = df.drop(columns=['target', 'label']).to_numpy()
y = df.label.to_numpy()

# Create splits
records = repeated_splits(X, y, n_loops=n_loops,
                                n_splits=n_splits)

# Compare if all consecutive records are equal
print("\nExample II:")
for i, (current, following) in enumerate(zip(records, records[1:])):
    print('{0} == {1} : {2}'.format(i, i + 1,
        current.equals(following)))

Total running time of the script: ( 0 minutes 0.032 seconds)

Gallery generated by Sphinx-Gallery