Couches d’attention (MultiHeadAttention)

Fonctions :

  • MultiHeadAttention()

    La couche MultiHeadAttention implémente le mécanisme d’attention multi-têtes, qui permet au modèle de se concentrer simultanément sur différentes parties de la séquence d’entrée. Elle calcule plusieurs « têtes » d’attention parallèles, puis concatène et projette leurs résultats pour capturer des relations complexes dans les données séquentielles.

    Importation :

    from tensorflow.keras.layers import MultiHeadAttention
    import numpy as np

    Attributs :

    Paramètre Type Description Valeur par défaut
    num_heads int Nombre de têtes d’attention parallèles.
    key_dim int Dimension de la clé dans chaque tête d’attention.
    value_dim int ou None Dimension de la valeur dans chaque tête (par défaut égale à key_dim). None
    dropout float Taux de dropout appliqué sur les poids d’attention. 0.0
    use_bias bool Indique si les projections linéaires utilisent un biais. True

    Exemple de code :

    from tensorflow.keras.layers import MultiHeadAttention
    import numpy as np
    
    # Données simulées : batch de 1 séquence de 4 vecteurs de dimension 8
    query = np.random.rand(1, 4, 8).astype(np.float32)
    value = np.random.rand(1, 4, 8).astype(np.float32)
    key = np.random.rand(1, 4, 8).astype(np.float32)
    
    # Couche MultiHeadAttention avec 2 têtes, dimension clé 4
    mha = MultiHeadAttention(num_heads=2, key_dim=4)
    
    # Calcul de l'attention
    output = mha(query=query, value=value, key=key)
    
    print("Shape sortie :", output.shape)  # (1, 4, 8)
    

    Explication du code :

    Importation de la couche
    La couche `MultiHeadAttention` est importée depuis Keras.
    Préparation des données
    Les tenseurs `query`, `key` et `value` sont des séquences simulées, chacun contenant 4 vecteurs de dimension 8 dans un batch de taille 1.
    Définition de la couche
    On crée une instance de `MultiHeadAttention` avec 2 têtes et une dimension de clé de 4.
    Calcul de l’attention
    La couche calcule les scores d’attention entre `query`, `key` et `value` en parallèle sur les têtes, puis concatène les résultats.
    Sortie
    Le tenseur de sortie garde la même forme que la séquence d’entrée, ici (1, 4, 8).