02. Data splitters

Useful methods to split data into (i) hold-out (HOS) and cross-validation (CVS) sets or (ii) cross-validation folds.

Warning

Not completed!

 11 # Libraries
 12 import pandas as pd
 13 import numpy as np
 14
 15 # Libraries specific
 16 from sklearn.model_selection import train_test_split
 17 from sklearn.model_selection import StratifiedKFold
 18 from sklearn.model_selection import KFold
 19
 20
 21 try:
 22     __file__
 23     TERMINAL = True
 24 except:
 25     TERMINAL = False
 26
 27 def split_dataframe_hos_cvs(dataframe,  **kwargs):
 28     """This method labels the dataframe hos and cvs sets.
 29
 30     Parameters
 31     ----------
 32     dataframe: np.array or pd.DataFrame
 33         The data to be divided into HOS/CVS.
 34
 35     Returns
 36     -------
 37     np.array:
 38         The outcome is a numpy array with rows labelled as
 39         cvs (cross-validation set) and hos (hold-out set).
 40         :param data:
 41         :param inplace:
 42     """
 43     # Check it is a dataframe
 44     if not isinstance(dataframe, pd.DataFrame):
 45         raise TypeError
 46
 47     # Length
 48     n = dataframe.shape[0]
 49
 50     # Split in hos and training sets
 51     cvs, hos = train_test_split(np.arange(n), **kwargs)
 52
 53     # Create result
 54     empty = np.array([None]*n)
 55     empty[cvs] = 'cvs'
 56     empty[hos] = 'hos'
 57
 58     # Include
 59     dataframe['sets'] = empty
 60
 61     # Return
 62     return dataframe
 63
 64
 65 def split_dataframe_cvs_folds(dataframe, splitter,
 66             selected_rows=None, **kwargs):
 67     """This method labels the different folds.
 68
 69         .. note:
 70
 71     Parameters
 72     ----------
 73     dataframe: np.array or pd.DataFrame
 74         The data to be divided into folds.
 75
 76     splitter: str or splitter
 77         The splitter which can be an str or an splitter from the
 78         sklearn library which implementeds the method split.
 79
 80     selected_rows: array of bools.
 81         The rows to be considered to create the folds. Note that if
 82         y is passed (for stratified cross validation) y will also be
 83         filtered by these rows.
 84
 85     kwargs:
 86
 87     Returns
 88     -------
 89     pd.DataFrame:
 90         The outcome is the same dataframe with an additional column
 91         <set> with the values cvs (cross-validation set) and hos
 92         (hold-out set).
 93     """
 94     # Check it is a dataframe
 95     if not isinstance(dataframe, pd.DataFrame):
 96         raise TypeError
 97
 98     # Get splitter from string
 99     if isinstance(splitter, str):
100         splitter = _DEFAULT_SPLITTERS[splitter]
101
102     # Define X and y
103     #X = dataframe[dataframe.sets == 'cvs'].index.to_numpy()
104     #y = dataframe[dataframe.sets == 'cvs'][label]
105
106     # Shape
107     r, c = dataframe.shape
108
109     # No rows selected (all by default)
110     if selected_rows is None:
111         selected_rows = np.full(r, True, dtype=bool)
112
113     # Select rows from y
114     if 'y' in kwargs:
115         if kwargs['y'] is not None:
116             kwargs['y'] = kwargs['y'][selected_rows]
117
118     # Create indexes to use for splitting
119     idxs = np.arange(r)[selected_rows].reshape(-1, 1)
120
121     # Get splits of idxs
122     splits = splitter.split(idxs, **kwargs)
123
124     # Loop and add
125     for i, (train, test) in enumerate(splits):
126         dataframe['split_{0}'.format(i)] = None
127         dataframe.loc[idxs[train].flatten(), 'split_{0}'.format(i)] = 'train'
128         dataframe.loc[idxs[test].flatten(), 'split_{0}'.format(i)] = 'test'
129
130     # Return
131     return dataframe
132
133
def split_dataframe_completeness(dataframe):
    """Placeholder for a split based on completeness.

    Not implemented yet: the dataframe is ignored and None is returned.
    """
    return None
136
137
138
class DataframeHOSCSVSplitter():
    """Splitter which labels the rows of a dataframe as CVS or HOS.

    The labels are written in the column ``col_name`` using the values
    ``cvs_name`` (cross-validation set) and ``hos_name`` (hold-out set).
    """
    # Default name of the column in which the labels are stored.
    col_name = 'sets'
    # Default label for the rows in the cross-validation set.
    cvs_name = 'CVS'
    # Default label for the rows in the hold-out set.
    hos_name = 'HOS'

    def __init__(self, col_name=None,
                       cvs_name=None,
                       hos_name=None):
        """Constructor

        Any parameter left as None keeps the class default.

        :param col_name: name of the column to store the labels.
        :param cvs_name: label for the cross-validation set rows.
        :param hos_name: label for the hold-out set rows.
        """
        if col_name is not None:
            self.col_name = col_name
        if cvs_name is not None:
            self.cvs_name = cvs_name
        if hos_name is not None:
            self.hos_name = hos_name

    def split(self, dataframe, **kwargs):
        """Labels the rows of the dataframe as CVS or HOS.

        The dataframe is modified in place (the column ``col_name`` is
        added or overwritten) and the same object is returned.

        :param dataframe: the pd.DataFrame to label.
        :param kwargs: keyword arguments forwarded to train_test_split.
        :return: the same dataframe including the column ``col_name``.
        :raises TypeError: if dataframe is not a pd.DataFrame.
        """
        # Check it is a dataframe
        if not isinstance(dataframe, pd.DataFrame):
            raise TypeError(
                "The parameter <dataframe> must be a pd.DataFrame "
                "but <{0}> was received.".format(type(dataframe).__name__))

        # Number of rows
        n = dataframe.shape[0]

        # Split row positions rather than index labels so that rows
        # sharing an index label are still labelled independently.
        cvs, hos = train_test_split(np.arange(n), **kwargs)

        # Create labels by position
        labels = np.full(n, None, dtype=object)
        labels[cvs] = self.cvs_name
        labels[hos] = self.hos_name

        # Fill dataset (in place)
        dataframe[self.col_name] = labels

        # Return
        return dataframe
175
176
177
# Default splitters: shuffled stratified k-folds keyed as 'skfold<k>'.
_DEFAULT_SPLITTERS = {
    'skfold{0}'.format(k): StratifiedKFold(n_splits=k, shuffle=True)
    for k in (10, 5, 2)
}
184
185
# --------------------------------------------------
# Main
# --------------------------------------------------
# Libraries
from sklearn.datasets import load_iris

# Load the iris dataset requesting pandas objects (as_frame=True)
bunch = load_iris(as_frame=True)

# Keep the data attribute of the bunch as the working dataframe
dataframe = bunch['data']

Let's see the dataset.

# Print the dataframe when TERMINAL is set; the bare expression at the
# end is kept as the last statement of the snippet.
if TERMINAL:
    print("\nData", dataframe, sep="\n")
dataframe
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)
0 5.1 3.5 1.4 0.2
1 4.9 3.0 1.4 0.2
2 4.7 3.2 1.3 0.2
3 4.6 3.1 1.5 0.2
4 5.0 3.6 1.4 0.2
... ... ... ... ...
145 6.7 3.0 5.2 2.3
146 6.3 2.5 5.0 1.9
147 6.5 3.0 5.2 2.0
148 6.2 3.4 5.4 2.3
149 5.9 3.0 5.1 1.8

150 rows × 4 columns



Let's split the data into HOS and CVS.

# Label the rows as HOS or CVS (stored in the column 'sets')
df = split_dataframe_hos_cvs(dataframe)

# Show
if TERMINAL:
    print("\nData", df, sep="\n")
df
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) sets
0 5.1 3.5 1.4 0.2 cvs
1 4.9 3.0 1.4 0.2 cvs
2 4.7 3.2 1.3 0.2 cvs
3 4.6 3.1 1.5 0.2 cvs
4 5.0 3.6 1.4 0.2 cvs
... ... ... ... ... ...
145 6.7 3.0 5.2 2.3 cvs
146 6.3 2.5 5.0 1.9 cvs
147 6.5 3.0 5.2 2.0 hos
148 6.2 3.4 5.4 2.3 cvs
149 5.9 3.0 5.1 1.8 cvs

150 rows × 5 columns



Let's split the CVS into various folds.

# Create the folds using only the rows labelled as 'cvs'; the target
# is passed as y since the 'skfold5' splitter is stratified.
df = split_dataframe_cvs_folds(
    dataframe,
    splitter='skfold5',
    y=bunch.target,
    selected_rows=(dataframe.sets == 'cvs'),
)

# Show
if TERMINAL:
    print("\nData", df, sep="\n")
df
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) sets split_0 split_1 split_2 split_3 split_4
0 5.1 3.5 1.4 0.2 cvs train test train train train
1 4.9 3.0 1.4 0.2 cvs train train train train test
2 4.7 3.2 1.3 0.2 cvs test train train train train
3 4.6 3.1 1.5 0.2 cvs train train train train test
4 5.0 3.6 1.4 0.2 cvs train train train train test
... ... ... ... ... ... ... ... ... ... ...
145 6.7 3.0 5.2 2.3 cvs train train test train train
146 6.3 2.5 5.0 1.9 cvs train test train train train
147 6.5 3.0 5.2 2.0 hos None None None None None
148 6.2 3.4 5.4 2.3 cvs train train train test train
149 5.9 3.0 5.1 1.8 cvs test train train train train

150 rows × 10 columns



Let's split the data into HOS and CVS using the class.

Note

This might not be working properly!

# Divide in HOS and CVS with the class (default column and labels)
df = DataframeHOSCSVSplitter().split(dataframe)

# Show
if TERMINAL:
    print("\nData", df, sep="\n")
df
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) sets split_0 split_1 split_2 split_3 split_4
0 5.1 3.5 1.4 0.2 CVS train test train train train
1 4.9 3.0 1.4 0.2 CVS train train train train test
2 4.7 3.2 1.3 0.2 CVS test train train train train
3 4.6 3.1 1.5 0.2 CVS train train train train test
4 5.0 3.6 1.4 0.2 CVS train train train train test
... ... ... ... ... ... ... ... ... ... ...
145 6.7 3.0 5.2 2.3 CVS train train test train train
146 6.3 2.5 5.0 1.9 CVS train test train train train
147 6.5 3.0 5.2 2.0 CVS None None None None None
148 6.2 3.4 5.4 2.3 CVS train train train test train
149 5.9 3.0 5.1 1.8 CVS test train train train train

150 rows × 10 columns



Total running time of the script: ( 0 minutes 0.033 seconds)

Gallery generated by Sphinx-Gallery