Couches d’attention (MultiHeadAttention)
Fonctions :
-
MultiHeadAttention()
La couche MultiHeadAttention implémente le mécanisme d’attention multi-têtes, qui permet au modèle de se concentrer simultanément sur différentes parties de la séquence d’entrée. Elle calcule plusieurs « têtes » d’attention parallèles, puis concatène et projette leurs résultats pour capturer des relations complexes dans les données séquentielles.
Importation :
from tensorflow.keras.layers import MultiHeadAttention import numpy as np
Attributs :
Paramètre Type Description Valeur par défaut num_heads
int Nombre de têtes d’attention parallèles. — key_dim
int Dimension de la clé dans chaque tête d’attention. — value_dim
int ou None
Dimension de la valeur dans chaque tête (par défaut égale à key_dim
).None
dropout
float Taux de dropout appliqué sur les poids d’attention. 0.0
use_bias
bool Indique si les projections linéaires utilisent un biais. True
Exemple de code :
from tensorflow.keras.layers import MultiHeadAttention import numpy as np # Données simulées : batch de 1 séquence de 4 vecteurs de dimension 8 query = np.random.rand(1, 4, 8).astype(np.float32) value = np.random.rand(1, 4, 8).astype(np.float32) key = np.random.rand(1, 4, 8).astype(np.float32) # Couche MultiHeadAttention avec 2 têtes, dimension clé 4 mha = MultiHeadAttention(num_heads=2, key_dim=4) # Calcul de l'attention output = mha(query=query, value=value, key=key) print("Shape sortie :", output.shape) # (1, 4, 8)
Explication du code :
Importation de la couche
La couche `MultiHeadAttention` est importée depuis Keras.Préparation des données
Les tenseurs `query`, `key` et `value` sont des séquences simulées, chacun contenant 4 vecteurs de dimension 8 dans un batch de taille 1.Définition de la couche
On crée une instance de `MultiHeadAttention` avec 2 têtes et une dimension de clé de 4.Calcul de l’attention
La couche calcule les scores d’attention entre `query`, `key` et `value` en parallèle sur les têtes, puis concatène les résultats.Sortie
Le tenseur de sortie garde la même forme que la séquence d’entrée, ici (1, 4, 8).