SHAP Summary Plot : Visualisation globale de l’importance des variables

Le SHAP Summary Plot est une visualisation essentielle pour interpréter de manière globale un modèle de machine learning. Il permet de comprendre quelles variables influencent le plus les prédictions du modèle et comment elles les influencent.


Que montre ce graphique ?

Le summary plot combine deux niveaux d’information :


Pourquoi l’utiliser ?


Exemple d’interprétation

Si une variable comme age apparaît très haut dans le graphe :

Fonctions :

  • shap.summary_plot()

    La fonction shap.summary_plot() est utilisée pour afficher un graphique résumant l'importance des caractéristiques dans un modèle de machine learning en utilisant les valeurs de Shapley. Ce graphique montre à quel point chaque caractéristique contribue aux prédictions du modèle. Il peut être utilisé pour observer l'impact global des différentes caractéristiques sur les prédictions, ainsi que les interactions entre elles. Le graphique peut être affiché sous différentes formes, notamment un graphique en points (scatter) ou un graphique de type barplot pour une vue d'ensemble de l'importance des caractéristiques.

    Importation :

    import shap

    Attributs :

    Paramètre Type Valeur par défaut Description
    shap_values array-like None Les valeurs de Shapley calculées par l'explainer. Ce paramètre représente l'impact de chaque caractéristique sur les prédictions du modèle.
    features array-like, pandas DataFrame None Les données d'entrée qui ont été utilisées pour calculer les valeurs de Shapley. Cela peut être un tableau numpy ou un DataFrame pandas représentant les caractéristiques du jeu de données.
    max_display int 20 Le nombre maximal de caractéristiques à afficher sur le graphique. Par défaut, les 20 caractéristiques les plus importantes sont affichées.
    plot_type str 'dot' Le type de graphique à afficher. 'dot' est le type par défaut, où chaque point représente l'impact d'une caractéristique sur une prédiction individuelle. D'autres types de graphiques comme 'bar' peuvent être utilisés pour afficher un graphique de type barplot.
    color str 'coolwarm' La palette de couleurs utilisée pour le graphique. Par défaut, 'coolwarm' est utilisé pour colorier les points en fonction de leur valeur.

    Exemple de code :

    import shap
    import xgboost
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    
    # Charger un jeu de données d'exemple (Iris)
    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
    # Entraîner un modèle XGBoost
    model = xgboost.XGBClassifier()
    model.fit(X_train, y_train)
    
    # Créer un explainer SHAP pour le modèle XGBoost
    explainer = shap.TreeExplainer(model)
    
    # Calculer les valeurs SHAP pour l'ensemble de test
    shap_values = explainer.shap_values(X_test)
    
    # Visualiser l'impact des caractéristiques sur les prédictions
    shap.summary_plot(shap_values, X_test)

    Explication du code :

    1. Données et modèle : Le jeu de données Iris est chargé, puis divisé en ensembles d'entraînement et de test. Un modèle XGBoost est ensuite entraîné sur les données d'entraînement.

    2. Création de l'explainer : Un objet TreeExplainer est créé pour expliquer les prédictions du modèle XGBoost.

    3. Calcul des valeurs SHAP : Nous appelons explainer.shap_values() pour calculer les valeurs de Shapley sur l'ensemble de test. Ces valeurs permettent d'évaluer l'impact de chaque caractéristique sur la prédiction du modèle.

    4. Visualisation : Nous utilisons shap.summary_plot() pour afficher un graphique récapitulant l'importance de chaque caractéristique sur les prédictions du modèle. Le graphique affiché montre la distribution des valeurs de Shapley pour chaque caractéristique et leur impact sur les prédictions globales.

    Sortie attendue :

    • Un graphique interactif affichant les caractéristiques les plus importantes du modèle, où chaque point représente l'impact d'une caractéristique sur une prédiction. La couleur des points est généralement utilisée pour représenter la valeur de la caractéristique, et la distribution des points montre l'impact de cette caractéristique sur les prédictions du modèle.

      • Scatter plot (dot plot) : Par défaut, shap.summary_plot() affiche un graphique en points, où chaque point correspond à un échantillon du jeu de test, et l'axe des x représente la valeur de la caractéristique pour chaque échantillon. La couleur du point représente la valeur de la caractéristique (par exemple, rouge pour des valeurs élevées, bleu pour des valeurs faibles).

      • Bar plot : Si vous préférez afficher les caractéristiques par ordre d'importance, vous pouvez utiliser plot_type='bar', ce qui affichera un graphique en barres montrant la somme des valeurs de Shapley pour chaque caractéristique.

    Applications :

    • Interprétation des modèles complexes : Le summary_plot permet d'expliquer de manière visuelle et intuitive l'impact de chaque caractéristique sur les prédictions du modèle. Il aide à identifier les caractéristiques les plus importantes et à comprendre comment elles influencent les résultats.

    • Analyse de l'importance des caractéristiques : Le graphique montre non seulement l'importance des caractéristiques, mais aussi la distribution de leur influence, ce qui peut aider à comprendre les interactions entre les différentes caractéristiques et leur impact sur les prédictions.

    • Amélioration de la transparence du modèle : Cette visualisation est particulièrement utile dans des contextes où la transparence des décisions du modèle est cruciale, comme dans les domaines médical, financier, ou juridique.