Stratified K-fold
Fonctions :
-
StratifiedKFold()
La fonction StratifiedKFold() de la bibliothèque sklearn.model_selection est une variante de la fonction KFold() qui effectue la division des données en sous-ensembles (ou "folds") tout en maintenant la proportion des classes dans chaque fold, c'est-à-dire que chaque fold contient la même proportion de classes cibles que l'ensemble de données d'origine. Cela est particulièrement utile lorsque les classes sont déséquilibrées, afin d'assurer une répartition équilibrée des classes dans chaque sous-ensemble pour une évaluation plus fiable.
Importation :
from sklearn.model_selection import StratifiedKFold
Attributs :
Paramètre
Description
n_splits
Le nombre de "folds" (divisions) à réaliser. Par défaut, il est égal à 5. shuffle
Si True
, les données sont mélangées avant la division. Par défaut, c'estFalse
.random_state
Permet de fixer la graine pour la génération aléatoire des folds si shuffle=True
.Exemple de code :
import numpy as np from sklearn.model_selection import StratifiedKFold from sklearn.linear_model import LogisticRegression from sklearn.datasets import load_iris from sklearn.metrics import accuracy_score # Chargement du jeu de données Iris data = load_iris() X = data.data y = data.target # Initialisation du StratifiedKFold avec 5 folds skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) # Liste pour stocker les scores de chaque fold fold_accuracies = [] # Création du modèle model = LogisticRegression(max_iter=200) # Validation croisée stratifiée for train_index, test_index in skf.split(X, y): # Séparation des données en train et test pour ce fold X_train, X_test = X[train_index], X[test_index] y_train, y_test = y[train_index], y[test_index] # Entraînement du modèle model.fit(X_train, y_train) # Prédiction et évaluation y_pred = model.predict(X_test) accuracy = accuracy_score(y_test, y_pred) fold_accuracies.append(accuracy) # Affichage de la moyenne des accuracies sur tous les folds print(f'Mean accuracy: {np.mean(fold_accuracies):.4f}')
Explication du code :
-
Chargement des données :
-
Nous chargeons le jeu de données Iris avec
load_iris()
depuissklearn.datasets
.
-
-
Initialisation de StratifiedKFold :
-
Nous créons une instance de
StratifiedKFold
avec 5 splits,shuffle=True
pour mélanger les données, et une graine aléatoire fixée à 42 pour rendre les résultats reproductibles.
-
-
Boucle de validation croisée stratifiée :
-
Nous utilisons
skf.split(X, y)
pour diviser les données en indices de training et de test, en assurant que chaque fold contient la même proportion de classes que l'ensemble de données original. Cela est particulièrement important lorsque les classes sont déséquilibrées.
-
-
Entraînement et évaluation :
-
À chaque itération, nous entraînons un modèle de régression logistique sur les données d'entraînement et évaluons la précision sur les données de test.
-
-
Calcul des résultats :
-
Les résultats de chaque fold sont stockés dans la liste
fold_accuracies
, et à la fin, nous affichons la moyenne des précisions sur tous les folds.
-
Points importants :
-
Équilibrage des classes :
StratifiedKFold
garantit que la proportion de chaque classe dans chaque fold est représentative de l'ensemble de données complet. Cela est crucial lorsqu'on travaille avec des ensembles de données où certaines classes sont sous-représentées. -
Mélange des données : Comme avec
KFold
, vous pouvez mélanger les données avant de les diviser avecshuffle=True
. Cela permet de réduire le biais potentiel lié à l'ordre des données dans l'ensemble original. -
Répétabilité : Fixer un
random_state
permet d'assurer que les splits sont toujours les mêmes lors de différentes exécutions du code.
Applications courantes :
-
Validation croisée avec classes déséquilibrées : Lorsqu'on travaille sur des problèmes de classification avec des classes très déséquilibrées (par exemple, dans la détection de fraude ou de maladies rares),
StratifiedKFold
garantit que chaque fold a une répartition représentative des classes. -
Hyperparameter tuning : Comme pour
KFold
,StratifiedKFold
est souvent utilisé pour la validation croisée afin de trouver les meilleurs hyperparamètres pour un modèle tout en assurant que les classes sont bien représentées.
-