.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_examples/scikits/plot_data_splitters.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download__examples_scikits_plot_data_splitters.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__examples_scikits_plot_data_splitters.py:


02. Data splitters
==================

Useful methods to split data into (i) a hold-out set (HOS) and a
cross-validation set (CVS) or (ii) cross-validation folds.

.. warning:: Not completed!

.. GENERATED FROM PYTHON SOURCE LINES 11-198

.. code-block:: default
    :lineno-start: 11

    # Libraries
    import pandas as pd
    import numpy as np

    # Libraries specific
    from sklearn.model_selection import train_test_split
    from sklearn.model_selection import StratifiedKFold
    from sklearn.model_selection import KFold

    # Detect whether the script runs from a file or interactively.
    try:
        __file__
        TERMINAL = True
    except NameError:
        TERMINAL = False


    def split_dataframe_hos_cvs(dataframe, **kwargs):
        """This method labels the dataframe rows as hos or cvs.

        Parameters
        ----------
        dataframe: pd.DataFrame
            The data to be divided into HOS/CVS.

        kwargs:
            Arguments passed to train_test_split.

        Returns
        -------
        pd.DataFrame:
            The same dataframe with an additional column 'sets'
            whose values are cvs (cross-validation set) or
            hos (hold-out set).
        """
        # Check it is a dataframe
        if not isinstance(dataframe, pd.DataFrame):
            raise TypeError

        # Length
        n = dataframe.shape[0]

        # Split in hos and training sets
        cvs, hos = train_test_split(np.arange(n), **kwargs)

        # Create result
        empty = np.array([None]*n)
        empty[cvs] = 'cvs'
        empty[hos] = 'hos'

        # Include
        dataframe['sets'] = empty

        # Return
        return dataframe


    def split_dataframe_cvs_folds(dataframe, splitter,
            selected_rows=None, **kwargs):
        """This method labels the different folds.

        Parameters
        ----------
        dataframe: pd.DataFrame
            The data to be divided into folds.

        splitter: str or splitter
            The splitter, which can be a str (a key of
            _DEFAULT_SPLITTERS) or a splitter from the sklearn
            library which implements the method split.

        selected_rows: array of bools.
            The rows to be considered to create the folds. Note that
            if y is passed (for stratified cross validation) y will
            also be filtered by these rows.

        kwargs:
            Arguments passed to the split method (e.g. y).

        Returns
        -------
        pd.DataFrame:
            The outcome is the same dataframe with one additional
            column per fold (split_0, split_1, ...) with the values
            train and test. Rows that were not selected are left
            as None.
        """
        # Check it is a dataframe
        if not isinstance(dataframe, pd.DataFrame):
            raise TypeError

        # Get splitter from string
        if isinstance(splitter, str):
            splitter = _DEFAULT_SPLITTERS[splitter]

        # Define X and y
        #X = dataframe[dataframe.sets == 'cvs'].index.to_numpy()
        #y = dataframe[dataframe.sets == 'cvs'][label]

        # Shape
        r, c = dataframe.shape

        # No rows selected (all by default)
        if selected_rows is None:
            selected_rows = np.full(r, True, dtype=bool)

        # Select rows from y
        if 'y' in kwargs:
            if kwargs['y'] is not None:
                kwargs['y'] = kwargs['y'][selected_rows]

        # Create indexes to use for splitting
        idxs = np.arange(r)[selected_rows].reshape(-1, 1)

        # Get splits of idxs
        splits = splitter.split(idxs, **kwargs)

        # Loop and add
        for i, (train, test) in enumerate(splits):
            col = 'split_{0}'.format(i)
            dataframe[col] = None
            # idxs holds row positions; map them to index labels for .loc
            dataframe.loc[dataframe.index[idxs[train].flatten()], col] = 'train'
            dataframe.loc[dataframe.index[idxs[test].flatten()], col] = 'test'

        # Return
        return dataframe


    def split_dataframe_completeness(dataframe):
        pass


    class DataframeHOSCSVSplitter():
        """Labels the rows of a dataframe as CVS or HOS."""
        col_name = 'sets'
        cvs_name = 'CVS'
        hos_name = 'HOS'

        def __init__(self, col_name=None,
                           cvs_name=None,
                           hos_name=None):
            """Constructor

            :param col_name: name of the column that stores the labels.
            :param cvs_name: label used for the cross-validation set.
            :param hos_name: label used for the hold-out set.
            """
            if col_name is not None:
                self.col_name = col_name
            if cvs_name is not None:
                self.cvs_name = cvs_name
            if hos_name is not None:
                self.hos_name = hos_name

        def split(self, dataframe, **kwargs):
            """Labels each row of the dataframe as CVS or HOS.

            The kwargs are passed to train_test_split.
            """
            # Split
            cvs, hos = train_test_split(dataframe.index.to_numpy(), **kwargs)

            # Fill dataset
            dataframe[self.col_name] = None
            dataframe.loc[cvs, self.col_name] = self.cvs_name
            dataframe.loc[hos, self.col_name] = self.hos_name

            # Return
            return dataframe


    # Default splitters.
    _DEFAULT_SPLITTERS = {
        'skfold10': StratifiedKFold(n_splits=10, shuffle=True),
        'skfold5': StratifiedKFold(n_splits=5, shuffle=True),
        'skfold2': StratifiedKFold(n_splits=2, shuffle=True),
    }

    # --------------------------------------------------
    # Main
    # --------------------------------------------------
    # Libraries
    from sklearn.datasets import load_iris

    # Load data
    bunch = load_iris(as_frame=True)

    # Dataframe
    dataframe = bunch.data

.. GENERATED FROM PYTHON SOURCE LINES 199-200

Let's see the dataset

.. GENERATED FROM PYTHON SOURCE LINES 200-206

.. code-block:: default
    :lineno-start: 201

    if TERMINAL:
        print("\nData")
        print(dataframe)
    dataframe

.. csv-table::
    :header: "", "sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"

    0, 5.1, 3.5, 1.4, 0.2
    1, 4.9, 3.0, 1.4, 0.2
    2, 4.7, 3.2, 1.3, 0.2
    3, 4.6, 3.1, 1.5, 0.2
    4, 5.0, 3.6, 1.4, 0.2
    ..., ..., ..., ..., ...
    145, 6.7, 3.0, 5.2, 2.3
    146, 6.3, 2.5, 5.0, 1.9
    147, 6.5, 3.0, 5.2, 2.0
    148, 6.2, 3.4, 5.4, 2.3
    149, 5.9, 3.0, 5.1, 1.8

150 rows × 4 columns



.. GENERATED FROM PYTHON SOURCE LINES 207-208

Let's split the data into HOS and CVS sets

.. GENERATED FROM PYTHON SOURCE LINES 208-217

.. code-block:: default
    :lineno-start: 209

    # Split in HOS and CVS sets
    df = split_dataframe_hos_cvs(dataframe)

    if TERMINAL:
        print("\nData")
        print(df)
    df

.. csv-table::
    :header: "", "sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)", "sets"

    0, 5.1, 3.5, 1.4, 0.2, cvs
    1, 4.9, 3.0, 1.4, 0.2, cvs
    2, 4.7, 3.2, 1.3, 0.2, cvs
    3, 4.6, 3.1, 1.5, 0.2, cvs
    4, 5.0, 3.6, 1.4, 0.2, cvs
    ..., ..., ..., ..., ..., ...
    145, 6.7, 3.0, 5.2, 2.3, cvs
    146, 6.3, 2.5, 5.0, 1.9, cvs
    147, 6.5, 3.0, 5.2, 2.0, hos
    148, 6.2, 3.4, 5.4, 2.3, cvs
    149, 5.9, 3.0, 5.1, 1.8, cvs

150 rows × 5 columns
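The keyword arguments given to ``split_dataframe_hos_cvs`` are forwarded to
``train_test_split``, so the size and reproducibility of the hold-out set can
be controlled from the call. The block below is a minimal, self-contained
sketch of the same labelling idea. The helper name ``label_hos_cvs`` and the
values ``test_size=0.2`` and ``random_state=0`` are assumptions chosen for
illustration and are not part of the original script.

.. code-block:: python

    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split


    def label_hos_cvs(frame, **kwargs):
        """Return a copy of frame with a 'sets' column (hypothetical helper)."""
        n = frame.shape[0]
        # Split row positions; kwargs go straight to train_test_split.
        cvs, hos = train_test_split(np.arange(n), **kwargs)
        labels = np.empty(n, dtype=object)
        labels[cvs] = 'cvs'
        labels[hos] = 'hos'
        out = frame.copy()
        out['sets'] = labels
        return out


    bunch = load_iris(as_frame=True)

    # Assumed values: keep 20% of the rows as hold-out, stratified by class.
    labelled = label_hos_cvs(bunch.data, test_size=0.2,
                             stratify=bunch.target, random_state=0)

    counts = labelled['sets'].value_counts()
    print(counts)

With 150 rows and ``test_size=0.2``, 30 rows are labelled ``hos`` and 120
``cvs``. Working on a copy is a design choice of this sketch; the functions
above modify the dataframe they receive.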



.. GENERATED FROM PYTHON SOURCE LINES 218-219

Let's split the CVS into several folds

.. GENERATED FROM PYTHON SOURCE LINES 219-232

.. code-block:: default
    :lineno-start: 220

    # Split in folds
    df = split_dataframe_cvs_folds(dataframe,
        splitter='skfold5', y=bunch.target,
        selected_rows=(dataframe.sets == 'cvs'))

    if TERMINAL:
        print("\nData")
        print(df)
    df

.. csv-table::
    :header: "", "sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)", "sets", "split_0", "split_1", "split_2", "split_3", "split_4"

    0, 5.1, 3.5, 1.4, 0.2, cvs, train, test, train, train, train
    1, 4.9, 3.0, 1.4, 0.2, cvs, train, train, train, train, test
    2, 4.7, 3.2, 1.3, 0.2, cvs, test, train, train, train, train
    3, 4.6, 3.1, 1.5, 0.2, cvs, train, train, train, train, test
    4, 5.0, 3.6, 1.4, 0.2, cvs, train, train, train, train, test
    ..., ..., ..., ..., ..., ..., ..., ..., ..., ..., ...
    145, 6.7, 3.0, 5.2, 2.3, cvs, train, train, test, train, train
    146, 6.3, 2.5, 5.0, 1.9, cvs, train, test, train, train, train
    147, 6.5, 3.0, 5.2, 2.0, hos, None, None, None, None, None
    148, 6.2, 3.4, 5.4, 2.3, cvs, train, train, train, test, train
    149, 5.9, 3.0, 5.1, 1.8, cvs, test, train, train, train, train

150 rows × 10 columns
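Once the ``split_i`` columns exist, each of them can drive one round of
cross-validation. The following sketch rebuilds a dataframe with the same
column layout as the table above and fits one model per fold column. The
classifier (``LogisticRegression``), the seed and the variable names are
assumptions made for illustration only.

.. code-block:: python

    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import StratifiedKFold

    bunch = load_iris(as_frame=True)
    X, y = bunch.data, bunch.target
    features = list(X.columns)

    # Build 'split_i' columns holding 'train'/'test' labels.
    df = X.copy()
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
    for i, (train, test) in enumerate(skf.split(X, y)):
        col = 'split_{0}'.format(i)
        df[col] = None
        df.loc[df.index[train], col] = 'train'
        df.loc[df.index[test], col] = 'test'

    fold_cols = [c for c in df.columns if c.startswith('split_')]

    # One fit/score round per fold column.
    scores = []
    for col in fold_cols:
        train_mask = (df[col] == 'train').to_numpy()
        test_mask = (df[col] == 'test').to_numpy()
        model = LogisticRegression(max_iter=1000)
        model.fit(df.loc[train_mask, features], y[train_mask])
        scores.append(model.score(df.loc[test_mask, features], y[test_mask]))

    print(np.round(scores, 3))

Rows whose fold value is ``None`` (for instance hold-out rows) match neither
mask, so they are excluded from both fitting and scoring.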



.. GENERATED FROM PYTHON SOURCE LINES 233-236

Let's split into HOS and CVS using the class

.. note:: This might not be working properly!

.. GENERATED FROM PYTHON SOURCE LINES 236-244

.. code-block:: default
    :lineno-start: 237

    # Divide in HOS and CVS.
    df = DataframeHOSCSVSplitter().split(dataframe)

    if TERMINAL:
        print("\nData")
        print(df)
    df

.. csv-table::
    :header: "", "sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)", "sets", "split_0", "split_1", "split_2", "split_3", "split_4"

    0, 5.1, 3.5, 1.4, 0.2, CVS, train, test, train, train, train
    1, 4.9, 3.0, 1.4, 0.2, CVS, train, train, train, train, test
    2, 4.7, 3.2, 1.3, 0.2, CVS, test, train, train, train, train
    3, 4.6, 3.1, 1.5, 0.2, CVS, train, train, train, train, test
    4, 5.0, 3.6, 1.4, 0.2, CVS, train, train, train, train, test
    ..., ..., ..., ..., ..., ..., ..., ..., ..., ..., ...
    145, 6.7, 3.0, 5.2, 2.3, CVS, train, train, test, train, train
    146, 6.3, 2.5, 5.0, 1.9, CVS, train, test, train, train, train
    147, 6.5, 3.0, 5.2, 2.0, CVS, None, None, None, None, None
    148, 6.2, 3.4, 5.4, 2.3, CVS, train, train, train, test, train
    149, 5.9, 3.0, 5.1, 1.8, CVS, test, train, train, train, train

150 rows × 10 columns
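In the output above, ``sets`` was relabelled by the second split while the
``split_*`` columns still come from the earlier one, so the two no longer
agree (row 147 is labelled ``CVS`` but has no fold assignment). This may be
one reason behind the note above. The block below is a minimal sketch of one
way to avoid the mismatch, assuming it is acceptable to work on a copy and to
drop stale fold columns before relabelling. The helper ``relabel_hos_cvs``
and the toy dataframe are hypothetical and not part of the original script.

.. code-block:: python

    import pandas as pd
    from sklearn.model_selection import train_test_split


    def relabel_hos_cvs(frame, col_name='sets', cvs_name='CVS',
                        hos_name='HOS', **kwargs):
        """Label rows as CVS/HOS on a copy, dropping stale fold columns."""
        out = frame.copy()
        # Fold columns computed for a previous split are no longer valid.
        stale = [c for c in out.columns if str(c).startswith('split_')]
        out = out.drop(columns=stale)
        # Split index labels; kwargs go straight to train_test_split.
        cvs, hos = train_test_split(out.index.to_numpy(), **kwargs)
        out[col_name] = None
        out.loc[cvs, col_name] = cvs_name
        out.loc[hos, col_name] = hos_name
        return out


    # Toy frame carrying labels and one fold column from an earlier split.
    old = pd.DataFrame({
        'x': range(8),
        'sets': ['cvs'] * 6 + ['hos'] * 2,
        'split_0': ['train', 'test'] * 3 + [None, None],
    })

    new = relabel_hos_cvs(old, test_size=0.25, random_state=0)
    print(new)

The folds would then be recomputed on ``new`` using the fresh ``sets`` column,
while ``old`` keeps its original labels and folds.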



.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.033 seconds)


.. _sphx_glr_download__examples_scikits_plot_data_splitters.py:


.. only :: html

    .. container:: sphx-glr-footer
        :class: sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: plot_data_splitters.py <plot_data_splitters.py>`

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: plot_data_splitters.ipynb <plot_data_splitters.ipynb>`


.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_