How to get attention weights from attention neural network?


I have a model that uses an attention mechanism as below:

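```python
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Input, Softmax, TimeDistributed
from tensorflow.keras.models import Model

def create_model(feature_size, max_features, num_class):
    feature_input = Input((max_features, feature_size), dtype=tf.float32)
    feature_vectors = TimeDistributed(Dense(feature_size, use_bias=False, activation='tanh'))(feature_input)
    # Attention layer: one score per feature, normalized across the feature axis (axis=1).
    # The default axis=-1 is the size-1 score axis and would make every weight 1.0.
    attention_vectors = Dense(1)(feature_vectors)
    attention_weights = Softmax(axis=1)(attention_vectors)
    # Generating text vectors: attention-weighted sum over the features
    text_vectors = K.sum(feature_vectors * attention_weights, axis=1)
    # Prediction layer
    output_class = Dense(num_class, use_bias=False, activation='softmax')(text_vectors)
    model = Model(inputs=feature_input, outputs=output_class)
    return model
```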
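The attention scores fed to Softmax have shape (batch_size, max_features, 1), and the Keras Softmax layer normalizes along axis=-1 by default. Normalizing over that size-1 axis turns every score into 1.0, so the weights only behave like attention when they are normalized along the feature axis (axis=1). The NumPy sketch below stands in for the Keras layer (it does not call Keras) and uses made-up scores purely to show the difference:

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along the given axis
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Hypothetical attention scores: 2 inputs, 3 features, 1 score per feature
scores = np.array([[[0.5], [1.0], [-2.0]],
                   [[3.0], [0.0], [1.0]]])

last_axis = softmax(scores, axis=-1)    # normalizes each single score by itself
feature_axis = softmax(scores, axis=1)  # normalizes across the 3 features

print(last_axis[..., 0])                 # every entry is 1.0
print(feature_axis.sum(axis=1)[..., 0])  # each input's weights sum to 1
```

With axis=-1 every feature gets the same weight, so the weighted sum degenerates into a plain sum of the feature vectors.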

The training and testing code is given below:

```python
from tensorflow.keras import callbacks

model = create_model(feature_size, max_features, num_class)
# Compile model (compile settings omitted here)
# Check summary of model
model.summary()
# Early stopping: stop once val_loss has not improved for 20 epochs and keep the best weights
earlystopping = callbacks.EarlyStopping(monitor="val_loss",
                                        mode="min", patience=20,
                                        restore_best_weights=True)
# Train model
model.fit(x=X_train, y=Y_train, batch_size=64, epochs=200,
          validation_data=(X_test, Y_test), callbacks=[earlystopping])
# Performance
predicted = model.predict(x=X_test)
```

Here, the input has shape (batch_size, max_features, feature_size): max_features is the number of features per input, and each feature is a vector of length feature_size. The model computes an attention weight for each feature and uses these weights to collapse the features into a single vector, text_vectors, via a weighted sum (feature_vectors multiplied by attention_weights, summed over the feature axis). After training, I want to retrieve, for each test data point, the attention weights that were used to compute its text_vectors. How can I achieve that?
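As a worked sketch of the computation described above, the NumPy function below reproduces the model's forward pass up to the attention weights and the pooled vector, given the layer parameters. The names attention_forward, W_feat, w_att and b_att are hypothetical, the random arrays stand in for trained parameters, and softmax is taken over the feature axis (axis=1). In Keras the real arrays could presumably be read from the trained layers with layer.get_weights() (the TimeDistributed tanh Dense holds only a kernel, the Dense(1) scorer a kernel and a bias), but which layer is which is an assumption to check against model.summary().

```python
import numpy as np

def attention_forward(X, W_feat, w_att, b_att):
    """Recompute attention weights and pooled text vectors (hypothetical helper).

    X:      (batch_size, max_features, feature_size) inputs
    W_feat: (feature_size, feature_size) kernel of the tanh Dense (no bias)
    w_att:  (feature_size, 1) kernel of the Dense(1) scoring layer
    b_att:  (1,) bias of the Dense(1) scoring layer
    """
    feature_vectors = np.tanh(X @ W_feat)                # (batch, max_features, feature_size)
    scores = feature_vectors @ w_att + b_att             # (batch, max_features, 1)
    scores = scores - scores.max(axis=1, keepdims=True)  # stabilize exp
    exp_scores = np.exp(scores)
    weights = exp_scores / exp_scores.sum(axis=1, keepdims=True)  # softmax over features
    text_vectors = (feature_vectors * weights).sum(axis=1)        # (batch, feature_size)
    return weights[..., 0], text_vectors

# Random stand-ins for trained parameters and test inputs
rng = np.random.default_rng(0)
batch_size, max_features, feature_size = 4, 5, 8
X = rng.normal(size=(batch_size, max_features, feature_size))
W_feat = rng.normal(size=(feature_size, feature_size))
w_att = rng.normal(size=(feature_size, 1))
b_att = np.zeros(1)

weights, text_vectors = attention_forward(X, W_feat, w_att, b_att)
print(weights.shape, text_vectors.shape)  # (4, 5) (4, 8)
```

Each row of weights is one test point's distribution over its max_features features. Within Keras itself, a common alternative is a second Model that shares the trained layers and outputs the Softmax layer's output, e.g. Model(inputs=model.input, outputs=model.layers[i].output) followed by predict(X_test); the index i here is hypothetical and would need to be read off model.summary().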

I have already looked at several SO posts, such as this Sample
