Stratified K-fold
Fonctions :
-
StratifiedKFold()
La fonction StratifiedKFold() de la bibliothèque sklearn.model_selection est une variante de la fonction KFold() qui effectue la division des données en sous-ensembles (ou "folds") tout en maintenant la proportion des classes dans chaque fold, c'est-à-dire que chaque fold contient la même proportion de classes cibles que l'ensemble de données d'origine. Cela est particulièrement utile lorsque les classes sont déséquilibrées, afin d'assurer une répartition équilibrée des classes dans chaque sous-ensemble pour une évaluation plus fiable.
Importation :
from sklearn.model_selection import StratifiedKFoldAttributs :
Paramètre
Description
n_splitsLe nombre de "folds" (divisions) à réaliser. Par défaut, il est égal à 5. shuffleSi True, les données sont mélangées avant la division. Par défaut, c'estFalse.random_statePermet de fixer la graine pour la génération aléatoire des folds si shuffle=True.Exemple de code :
import numpy as np from sklearn.model_selection import StratifiedKFold from sklearn.linear_model import LogisticRegression from sklearn.datasets import load_iris from sklearn.metrics import accuracy_score # Chargement du jeu de données Iris data = load_iris() X = data.data y = data.target # Initialisation du StratifiedKFold avec 5 folds skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) # Liste pour stocker les scores de chaque fold fold_accuracies = [] # Création du modèle model = LogisticRegression(max_iter=200) # Validation croisée stratifiée for train_index, test_index in skf.split(X, y): # Séparation des données en train et test pour ce fold X_train, X_test = X[train_index], X[test_index] y_train, y_test = y[train_index], y[test_index] # Entraînement du modèle model.fit(X_train, y_train) # Prédiction et évaluation y_pred = model.predict(X_test) accuracy = accuracy_score(y_test, y_pred) fold_accuracies.append(accuracy) # Affichage de la moyenne des accuracies sur tous les folds print(f'Mean accuracy: {np.mean(fold_accuracies):.4f}')Explication du code :
-
Chargement des données :
-
Nous chargeons le jeu de données Iris avec
load_iris()depuissklearn.datasets.
-
-
Initialisation de StratifiedKFold :
-
Nous créons une instance de
StratifiedKFoldavec 5 splits,shuffle=Truepour mélanger les données, et une graine aléatoire fixée à 42 pour rendre les résultats reproductibles.
-
-
Boucle de validation croisée stratifiée :
-
Nous utilisons
skf.split(X, y)pour diviser les données en indices de training et de test, en assurant que chaque fold contient la même proportion de classes que l'ensemble de données original. Cela est particulièrement important lorsque les classes sont déséquilibrées.
-
-
Entraînement et évaluation :
-
À chaque itération, nous entraînons un modèle de régression logistique sur les données d'entraînement et évaluons la précision sur les données de test.
-
-
Calcul des résultats :
-
Les résultats de chaque fold sont stockés dans la liste
fold_accuracies, et à la fin, nous affichons la moyenne des précisions sur tous les folds.
-
Points importants :
-
Équilibrage des classes :
StratifiedKFoldgarantit que la proportion de chaque classe dans chaque fold est représentative de l'ensemble de données complet. Cela est crucial lorsqu'on travaille avec des ensembles de données où certaines classes sont sous-représentées. -
Mélange des données : Comme avec
KFold, vous pouvez mélanger les données avant de les diviser avecshuffle=True. Cela permet de réduire le biais potentiel lié à l'ordre des données dans l'ensemble original. -
Répétabilité : Fixer un
random_statepermet d'assurer que les splits sont toujours les mêmes lors de différentes exécutions du code.
Applications courantes :
-
Validation croisée avec classes déséquilibrées : Lorsqu'on travaille sur des problèmes de classification avec des classes très déséquilibrées (par exemple, dans la détection de fraude ou de maladies rares),
StratifiedKFoldgarantit que chaque fold a une répartition représentative des classes. -
Hyperparameter tuning : Comme pour
KFold,StratifiedKFoldest souvent utilisé pour la validation croisée afin de trouver les meilleurs hyperparamètres pour un modèle tout en assurant que les classes sont bien représentées.
-