Note
Click here to download the full example code
01. Test Skfold
Demonstrate that StratifiedKFold, with its default shuffle=False,
is deterministic and
always returns the same splits for the same data. The way to change the splits
is by enabling shuffling (shuffle=True) and changing the random state.
Out:
Example I:
0 == 1 : True
1 == 2 : True
2 == 3 : True
3 == 4 : True
Example II:
0 == 1 : True
1 == 2 : True
2 == 3 : True
3 == 4 : True
9 # General
10 import numpy as np
11 import pandas as pd
12
13 # Specific
14 from sklearn.model_selection import StratifiedKFold
15
16
17 # ---------------------------------------------------
18 #
19 # ---------------------------------------------------
def repeated_splits(X, y, n_loops=2, n_splits=5,
                    shuffle=False, random_state=None):
    """Create the stratified k-fold splits several times.

    A fresh ``StratifiedKFold`` splitter is built on every loop and
    the indices it yields are recorded, so that the returned records
    can be compared to check that the splitting is consistent.

    Parameters
    ----------
    X: array-like
        The samples, passed to ``StratifiedKFold.split``.
    y: array-like
        The labels used to stratify the folds.
    n_loops: int
        Number of times the splits are re-created.
    n_splits: int
        Number of folds requested from the splitter.
    shuffle: bool
        Forwarded to ``StratifiedKFold``. The default (False)
        keeps the original behaviour.
    random_state: int, RandomState instance or None
        Forwarded to ``StratifiedKFold``. It only has an effect
        when ``shuffle=True``; an int seed makes every loop
        produce the same shuffled splits.

    Returns
    -------
    list of pd.DataFrame
        One DataFrame per loop. Each has one column ``fold_<j>``
        per fold containing the train indices followed by the
        test indices of that fold.
    """
    # Record for comparison
    records = []

    # Split (the loop index itself is not needed)
    for _ in range(n_loops):
        # Create dataframe
        dataframe = pd.DataFrame()
        # Create splitter (new instance on every loop)
        skf = StratifiedKFold(n_splits=n_splits,
                              shuffle=shuffle,
                              random_state=random_state)
        # Loop over the folds, storing train + test indices
        for j, (train, test) in enumerate(skf.split(X, y)):
            dataframe['fold_{0}'.format(j)] = \
                np.concatenate((train, test))
        # Append
        records.append(dataframe)

    # Return
    return records
44
45
# ---------------------------------------------------
# Artificial example
# ---------------------------------------------------
# Dataset size and splitting configuration
n = 2000
n_splits = 5
n_loops = 5

# Random integer features and an imbalanced binary label
# (a label is 1 whenever the uniform draw exceeds 0.1)
X = np.random.randint(10, size=(n, 7))
y = (np.random.rand(n) > 0.1).astype(int)

# Build the same splits n_loops times
records = repeated_splits(X, y, n_loops=n_loops, n_splits=n_splits)

# Check every record against the one created right after it
print("\nExample I:")
for i, (current, following) in enumerate(zip(records, records[1:])):
    are_equal = current.equals(following)
    print('{0} == {1} : {2}'.format(i, i+1, are_equal))
67
68
69
# ---------------------------------------------------
# Real example
# ---------------------------------------------------
# Libraries
from sklearn.datasets import load_iris

# Load data
bunch = load_iris(as_frame=True)

# Label conversion (class index -> class name)
lblmap = dict(enumerate(bunch.target_names))

# Feature columns, captured before target/label are appended
# so that they can be kept out of X below.
features = list(bunch.data.columns)

# Dataframe
df = bunch.data
df['target'] = bunch.target
df['label'] = df.target.map(lblmap)

# Get X, and y (X holds only the feature columns; using the
# whole dataframe would leak target and label into X)
X = df[features].to_numpy()
y = df.label.to_numpy()

# Create splits
records = repeated_splits(X, y, n_loops=n_loops,
                          n_splits=n_splits)

# Compare if all consecutive records are equal
print("\nExample II:")
for i, (current, following) in enumerate(zip(records, records[1:])):
    print('{0} == {1} : {2}'.format(i, i+1,
                                    current.equals(following)))
Total running time of the script: ( 0 minutes 0.032 seconds)