scikit-learn train-test-split

August 16, 2023

Here's the relevant source code from scikit-learn:

  • class BaseShuffleSplit(_MetadataRequester, metaclass=ABCMeta):
        """Base class for ShuffleSplit and StratifiedShuffleSplit"""
    
        # This indicates that by default CV splitters don't have a "groups" kwarg,
        # unless indicated by inheriting from ``GroupsConsumerMixin``.
        # This also prevents ``set_split_request`` to be generated for splitters
        # which don't support ``groups``.
        __metadata_request__split = {"groups": metadata_routing.UNUSED}
    
        def __init__(
            self, n_splits=10, *, test_size=None, train_size=None, random_state=None
        ):
            self.n_splits = n_splits
            self.test_size = test_size
            self.train_size = train_size
            self.random_state = random_state
            self._default_test_size = 0.1
    
        def split(self, X, y=None, groups=None):
            X, y, groups = indexable(X, y, groups)
            for train, test in self._iter_indices(X, y, groups):
                yield train, test
    
        @abstractmethod
        def _iter_indices(self, X, y=None, groups=None):
            """Generate (train, test) indices"""
    
        def get_n_splits(self, X=None, y=None, groups=None):
            return self.n_splits
    
        def __repr__(self):
            return _build_repr(self)
    
    class ShuffleSplit(BaseShuffleSplit):
        """Random permutation cross-validator.

        Yields indices that partition the data into randomized train and
        test subsets.

        Note: contrary to other cross-validation strategies, random splits
        do not guarantee that all folds will be different, although this is
        still very likely for sizeable datasets.

        Parameters
        ----------
        n_splits : int, default=10
            Number of re-shuffling and splitting iterations.
        test_size : float, int or None, default=None
            Fraction (float in (0, 1)) or absolute count (int) of test
            samples.  If None, the complement of ``train_size`` is used; if
            ``train_size`` is also None, a default of 0.1 applies.
        train_size : float, int or None, default=None
            Fraction or absolute count of train samples.  If None, the
            complement of ``test_size`` is used.
        random_state : int, RandomState instance or None, default=None
            Controls the randomness of the shuffling.

        Examples
        --------
        >>> import numpy as np
        >>> from sklearn.model_selection import ShuffleSplit
        >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [3, 4], [5, 6]])
        >>> rs = ShuffleSplit(n_splits=5, test_size=.25, random_state=0)
        >>> rs.get_n_splits(X)
        5
        """

        def __init__(
            self, n_splits=10, *, test_size=None, train_size=None, random_state=None
        ):
            super().__init__(
                n_splits=n_splits,
                test_size=test_size,
                train_size=train_size,
                random_state=random_state,
            )
            # Used when neither test_size nor train_size is supplied.
            self._default_test_size = 0.1

        def _iter_indices(self, X, y=None, groups=None):
            """Yield ``n_splits`` independent random (train, test) index pairs."""
            sample_count = _num_samples(X)
            n_train, n_test = _validate_shuffle_split(
                sample_count,
                self.test_size,
                self.train_size,
                default_test_size=self._default_test_size,
            )

            rng = check_random_state(self.random_state)
            for _ in range(self.n_splits):
                # Shuffle every index, then carve the test block off the front
                # and the train block immediately after it.
                shuffled = rng.permutation(sample_count)
                test_idx = shuffled[:n_test]
                train_idx = shuffled[n_test : n_test + n_train]
                yield train_idx, test_idx
    
    
    def _validate_shuffle_split(n_samples, test_size, train_size, default_test_size=None):
        """
        Validation helper to check if the test/test sizes are meaningful w.r.t. the
        size of the data (n_samples).
        """
        if test_size is None and train_size is None:
            test_size = default_test_size
    
        test_size_type = np.asarray(test_size).dtype.kind
        train_size_type = np.asarray(train_size).dtype.kind
    
        if (
            test_size_type == "i"
            and (test_size >= n_samples or test_size <= 0)
            or test_size_type == "f"
            and (test_size <= 0 or test_size >= 1)
        ):
            raise ValueError(
                "test_size={0} should be either positive and smaller"
                " than the number of samples {1} or a float in the "
                "(0, 1) range".format(test_size, n_samples)
            )
    
        if (
            train_size_type == "i"
            and (train_size >= n_samples or train_size <= 0)
            or train_size_type == "f"
            and (train_size <= 0 or train_size >= 1)
        ):
            raise ValueError(
                "train_size={0} should be either positive and smaller"
                " than the number of samples {1} or a float in the "
                "(0, 1) range".format(train_size, n_samples)
            )
    
        if train_size is not None and train_size_type not in ("i", "f"):
            raise ValueError("Invalid value for train_size: {}".format(train_size))
        if test_size is not None and test_size_type not in ("i", "f"):
            raise ValueError("Invalid value for test_size: {}".format(test_size))
    
        if train_size_type == "f" and test_size_type == "f" and train_size + test_size > 1:
            raise ValueError(
                "The sum of test_size and train_size = {}, should be in the (0, 1)"
                " range. Reduce test_size and/or train_size.".format(train_size + test_size)
            )
    
        if test_size_type == "f":
            n_test = ceil(test_size * n_samples)
        elif test_size_type == "i":
            n_test = float(test_size)
    
        if train_size_type == "f":
            n_train = floor(train_size * n_samples)
        elif train_size_type == "i":
            n_train = float(train_size)
    
        if train_size is None:
            n_train = n_samples - n_test
        elif test_size is None:
            n_test = n_samples - n_train
    
        if n_train + n_test > n_samples:
            raise ValueError(
                "The sum of train_size and test_size = %d, "
                "should be smaller than the number of "
                "samples %d. Reduce test_size and/or "
                "train_size." % (n_train + n_test, n_samples)
            )
    
        n_train, n_test = int(n_train), int(n_test)
    
        if n_train == 0:
            raise ValueError(
                "With n_samples={}, test_size={} and train_size={}, the "
                "resulting train set will be empty. Adjust any of the "
                "aforementioned parameters.".format(n_samples, test_size, train_size)
            )
    
        return n_train, n_test
  • the function _iter_indices handles the splitting:

    • n_samples = _num_samples(X): Calculates the total number of samples in the dataset X. This is necessary for generating random permutations of indices.

    • _validate_shuffle_split(...): This is a utility function that calculates the number of samples in the training and test sets based on the specified test_size and train_size. If test_size is not specified, it uses a default value.

    • rng = check_random_state(self.random_state): Initializes a random number generator (RNG) using the random_state specified during the initialization of the ShuffleSplit instance. If random_state is not provided, a random seed is used.

    • for i in range(self.n_splits): Iterates for each split/iteration specified by n_splits.

    • permutation = rng.permutation(n_samples): Generates a random permutation of indices from 0 to n_samples - 1. This random permutation will be used to create shuffled train-test splits.

    • ind_test = permutation[:n_test]: Selects the first n_test indices from the shuffled permutation as test indices.

    • ind_train = permutation[n_test : (n_test + n_train)]: Selects the next n_train indices from the shuffled permutation as train indices.

    • yield ind_train, ind_test: Yields the train and test indices for the current iteration, effectively generating a train-test split for that iteration.

  • _validate_shuffle_split:

    • test_size_type = np.asarray(test_size).dtype.kind
          train_size_type = np.asarray(train_size).dtype.kind
    • These lines determine the type of the test_size and train_size parameters using NumPy's dtype.kind attribute.

    • It helps differentiate between integer (i) and floating-point (f) types for test_size and train_size.

    • if (
              test_size_type == "i"
              and (test_size >= n_samples or test_size <= 0)
              or test_size_type == "f"
              and (test_size <= 0 or test_size >= 1)
          ):
              raise ValueError("...")
    • This block of code checks if the provided test_size is valid. It raises a ValueError if the conditions are not met.

    • For integer values of test_size, it ensures that test_size is greater than 0 and less than the total number of samples (n_samples).

    • For floating-point values, it ensures that test_size is within the range (0, 1), which means it represents a proportion of the data.

    • if train_size_type == "f" and test_size_type == "f" and train_size + test_size > 1:
              raise ValueError("...")
      • If both train_size and test_size are floats, this check ensures that their sum does not exceed 1. If it does, it raises an error.

      • This is to prevent specifying conflicting proportions for train and test sets.

    • if test_size_type == "f":
          n_test = ceil(test_size * n_samples)
      elif test_size_type == "i":
          n_test = float(test_size)
      
      if train_size_type == "f":
          n_train = floor(train_size * n_samples)
      elif train_size_type == "i":
          n_train = float(train_size)
      • If test_size is specified as a float (a proportion), the code calculates the number of test samples (n_test) by multiplying the specified proportion (test_size) with the total number of samples (n_samples) and rounding up (ceil) to ensure a sufficient number of samples.

      • If test_size is specified as an integer, it's directly converted to a float, indicating the exact number of test samples.

      • Similar calculations are performed for train_size, converting proportions or specified sample counts into actual sample counts for the training dataset.

      • By performing these calculations, the function ensures that the sizes of the test and train datasets are well-defined and based on actual sample counts. This conversion from proportions to counts is important because it allows the function to create meaningful train-test splits that align with the requirements of machine learning algorithms. It bridges the gap between the user's input (specified proportions) and the algorithm's input (required sample counts).

    • n_train, n_test = int(n_train), int(n_test)
      • These lines convert the calculated train and test sizes to integers, as sample counts must be whole numbers.
  class StratifiedShuffleSplit(BaseShuffleSplit):
      """Stratified ShuffleSplit cross-validator

      Provides train/test indices to split data in train/test sets.

      This cross-validation object is a merge of StratifiedKFold and
      ShuffleSplit, which returns stratified randomized folds. The folds
      are made by preserving the percentage of samples for each class.

      Note: like the ShuffleSplit strategy, stratified random splits
      do not guarantee that all folds will be different, although this is
      still very likely for sizeable datasets.
      """

      def __init__(
          self, n_splits=10, *, test_size=None, train_size=None, random_state=None
      ):
          super().__init__(
              n_splits=n_splits,
              test_size=test_size,
              train_size=train_size,
              random_state=random_state,
          )
          # Fallback used when neither test_size nor train_size is given.
          self._default_test_size = 0.1

      def _iter_indices(self, X, y, groups=None):
          """Yield stratified (train, test) index arrays, one pair per split."""
          n_samples = _num_samples(X)
          # ensure_2d=False allows y to be 1-D; 2-D (multi-label) y is
          # handled separately below.
          y = check_array(y, input_name="y", ensure_2d=False, dtype=None)
          # Resolve test_size/train_size into absolute sample counts.
          n_train, n_test = _validate_shuffle_split(
              n_samples,
              self.test_size,
              self.train_size,
              default_test_size=self._default_test_size,
          )

          if y.ndim == 2:
              # for multi-label y, map each distinct row to a string repr
              # using join because str(row) uses an ellipsis if len(row) > 1000
              y = np.array([" ".join(row.astype("str")) for row in y])

          # y_indices maps each sample to its position in `classes`.
          classes, y_indices = np.unique(y, return_inverse=True)
          n_classes = classes.shape[0]

          # Per-class sample counts; stratification needs >= 2 per class so
          # every class can appear in both the train and the test set.
          class_counts = np.bincount(y_indices)
          if np.min(class_counts) < 2:
              raise ValueError(
                  "The least populated class in y has only 1"
                  " member, which is too few. The minimum"
                  " number of groups for any class cannot"
                  " be less than 2."
              )

          if n_train < n_classes:
              raise ValueError(
                  "The train_size = %d should be greater or "
                  "equal to the number of classes = %d" % (n_train, n_classes)
              )
          if n_test < n_classes:
              raise ValueError(
                  "The test_size = %d should be greater or "
                  "equal to the number of classes = %d" % (n_test, n_classes)
              )

          # Find the sorted list of instances for each class:
          # (np.unique above performs a sort, so code is O(n logn) already)
          class_indices = np.split(
              np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1]
          )

          rng = check_random_state(self.random_state)

          for _ in range(self.n_splits):
              # if there are ties in the class-counts, we want
              # to make sure to break them anew in each iteration
              # n_i[i]: number of train samples drawn from class i.
              n_i = _approximate_mode(class_counts, n_train, rng)
              class_counts_remaining = class_counts - n_i
              # t_i[i]: number of test samples drawn from class i, taken
              # from what remains after the train allocation.
              t_i = _approximate_mode(class_counts_remaining, n_test, rng)

              train = []
              test = []

              for i in range(n_classes):
                  # Shuffle within the class, then take the first n_i[i]
                  # indices for train and the next t_i[i] for test.
                  permutation = rng.permutation(class_counts[i])
                  perm_indices_class_i = class_indices[i].take(permutation, mode="clip")

                  train.extend(perm_indices_class_i[: n_i[i]])
                  test.extend(perm_indices_class_i[n_i[i] : n_i[i] + t_i[i]])

              # Shuffle once more so samples are not grouped by class in the
              # returned index arrays.
              train = rng.permutation(train)
              test = rng.permutation(test)

              yield train, test
  • if dataset has imbalanced class proportions (i.e., some classes have many more samples than others), maintaining the class distribution ensures that each fold used for training and testing contains a representative sample of each class. This is important because it helps prevent situations where one or more classes are underrepresented or entirely missing in a particular fold. Without maintaining the class distribution, certain classes could be left out or inadequately represented in some of the folds, leading to biased or unreliable model evaluation.

  • _iter_indices :

    • _num_samples(X): Calculates the total number of samples in the dataset X.

    • check_array(y, input_name="y", ensure_2d=False, dtype=None): Checks and converts the target array y into a NumPy array; ensure_2d=False allows y to be a one-dimensional array (rather than requiring a 2-D shape, as check_array does by default).

    • _validate_shuffle_split(...): A utility function that calculates the number of samples in the training and test sets based on the specified test_size and train_size. If these are not provided, it uses a default test size.

    • If the target array y is two-dimensional (indicating a multi-label scenario), each row of y is converted into a string representation using join. This is done to handle multi-label classes properly.

    • np.unique(y, return_inverse=True): Returns unique classes in the target array y and also returns an array of indices that map each element of y to its corresponding class index.

    • np.bincount(y_indices): Counts the occurrences of each unique class in the target array y.

    • Check if the minimum count of any class is less than 2. If so, it raises an error because cross-validation requires at least two samples in each class.

    • Ensure that the specified train_size and test_size are greater than or equal to the number of classes. This ensures that each class can be represented in both training and test sets.

    • class_indices = np.split(
          np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1]
      )
      • Splits the indices of y based on the unique classes. This creates separate index arrays for each class.

      • ensure that, during the stratified shuffle-based cross-validation process, each fold maintains the same distribution of classes as the original dataset.

      • np.argsort(y_indices, kind="mergesort"): This part sorts the indices of the target array y based on the class index they belong to. It returns the indices that would sort y_indices in ascending order.

      • np.cumsum(class_counts)[:-1]: This part calculates the cumulative sum of class counts and then excludes the last element. The resulting array indicates the indices where the classes change in the sorted indices array from last step.

      • np.split(...): This function splits the sorted indices array into segments based on the positions where the classes change. This effectively groups the indices belonging to each class into separate arrays within the class_indices list.

    • _approximate_mode(...): A utility function that calculates how many samples to allocate to each class in the training and test sets.

    • for i in range(n_classes):
          permutation = rng.permutation(class_counts[i])
          perm_indices_class_i = class_indices[i].take(permutation, mode="clip")
      
          train.extend(perm_indices_class_i[: n_i[i]])
          test.extend(perm_indices_class_i[n_i[i] : n_i[i] + t_i[i]])
      • For each class, shuffles the indices within the class to create random samples for both training and test sets.

      • Appends the shuffled indices to the train and test lists.

    • train = rng.permutation(train)
      test = rng.permutation(test)
      
      yield train, test
      • Shuffles the train and test lists to ensure randomness.

      • Yields the shuffled train and test indices for the current fold.

  • 
    @validate_params(
        {
            "test_size": [
                Interval(RealNotInt, 0, 1, closed="neither"),
                Interval(numbers.Integral, 1, None, closed="left"),
                None,
            ],
            "train_size": [
                Interval(RealNotInt, 0, 1, closed="neither"),
                Interval(numbers.Integral, 1, None, closed="left"),
                None,
            ],
            "random_state": ["random_state"],
            "shuffle": ["boolean"],
            "stratify": ["array-like", None],
        },
        prefer_skip_nested_validation=True,
    )
    def train_test_split(
        *arrays,
        test_size=None,
        train_size=None,
        random_state=None,
        shuffle=True,
        stratify=None,
    ):
        """Split each array in *arrays* into random train and test subsets.

        Returns ``2 * len(arrays)`` objects, interleaved as
        ``train_0, test_0, train_1, test_1, ...``.

        Raises
        ------
        ValueError
            If no array is supplied, or if ``stratify`` is combined with
            ``shuffle=False``.
        """
        if not arrays:
            raise ValueError("At least one array required as input")

        arrays = indexable(*arrays)

        # Resolve test_size/train_size (fractions or counts) into absolute
        # sample counts, defaulting to a 25% test set.
        n_samples = _num_samples(arrays[0])
        n_train, n_test = _validate_shuffle_split(
            n_samples, test_size, train_size, default_test_size=0.25
        )

        if shuffle is False:
            if stratify is not None:
                raise ValueError(
                    "Stratified train/test split is not implemented for shuffle=False"
                )

            # Deterministic contiguous split: the first n_train samples form
            # the train set, the following n_test samples the test set.
            train = np.arange(n_train)
            test = np.arange(n_train, n_train + n_test)

        else:
            # Delegate the (optionally stratified) shuffling to a splitter
            # and take its first generated split.
            CVClass = StratifiedShuffleSplit if stratify is not None else ShuffleSplit
            cv = CVClass(test_size=n_test, train_size=n_train, random_state=random_state)
            train, test = next(cv.split(X=arrays[0], y=stratify))

        # For every input array, select the train and test rows and flatten
        # the resulting pairs into one interleaved list.
        split_pairs = (
            (_safe_indexing(a, train), _safe_indexing(a, test)) for a in arrays
        )
        return list(chain.from_iterable(split_pairs))
    • decorators

      • "test_size" and "train_size": The decorator specifies valid intervals for these parameters. It ensures that test_size and train_size can be either a float in the open interval (0, 1) or an integer greater than or equal to 1. This ensures that the proportions are meaningful and within appropriate ranges.

      • "random_state": This parameter is expected to be a valid random state object. This ensures that users provide a valid random state for controlling randomness in the data split.

      • "shuffle": This parameter is expected to be a boolean value. It ensures that users provide a valid boolean value to determine whether shuffling is enabled.

      • "stratify": This parameter can be an array-like object or None. It allows users to specify whether stratification should be applied during the data split.

    • n_samples = _num_samples(arrays[0])
      n_train, n_test = _validate_shuffle_split(
          n_samples, test_size, train_size, default_test_size=0.25
      )
      • These lines calculate the number of samples in the first array (assumed to be the feature matrix) using the _num_samples function. Then, it calls _validate_shuffle_split to calculate appropriate train and test sizes based on the specified test_size and train_size or their defaults. This ensures that valid and meaningful train-test split sizes are calculated.
    • if shuffle is False:
        if stratify is not None:
            raise ValueError(
                "Stratified train/test split is not implemented for shuffle=False"
            )
      train = np.arange(n_train)
      test = np.arange(n_train, n_train + n_test)
      • creates indices for the train and test sets without shuffling

      • When shuffling is disabled (shuffle=False), the goal is to create a deterministic and predictable split. Stratification becomes challenging without shuffling because the order of data remains fixed, which can lead to uneven class distributions in train and test sets. For example, if the data is ordered by class and you split without shuffling, one of the classes could end up entirely in the training set, while the other class is entirely in the test set.

    • else:
          if stratify is not None:
              CVClass = StratifiedShuffleSplit
          else:
              CVClass = ShuffleSplit
      
          cv = CVClass(test_size=n_test, train_size=n_train, random_state=random_state)
      
          train, test = next(cv.split(X=arrays[0], y=stratify))
      • These lines handle the case where shuffling is enabled (shuffle is True). If stratify is specified, it selects StratifiedShuffleSplit as the cross-validator class; otherwise, it uses ShuffleSplit. It then creates an instance of the selected cross-validator with the calculated train and test sizes, random state, and other settings. Finally, it uses the split method to generate indices for the train and test sets.
    • return list(
          chain.from_iterable(
              (_safe_indexing(a, train), _safe_indexing(a, test)) for a in arrays
          )
      )
      • These lines return the train and test sets as lists. It uses the _safe_indexing function to ensure that arrays are indexed safely without any out-of-bounds access. The chain.from_iterable function flattens the nested lists of train and test arrays from different input arrays.