I created a classification model using LSTM and Attention (a custom layer). The model trains fine with model.fit, and also with a custom training loop as long as I don't use tf.function. But when I wrap the training step in tf.function with a gradient tape, it gives this error:
ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function.
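The training step I'm using looks roughly like this (simplified; the optimizer, loss, and batch tensors are placeholders for my actual setup):
import tensorflow as tf

loss_fn = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        preds = model(x_batch, training=True)
        loss = loss_fn(y_batch, preds)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss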
Attention Layer
import tensorflow as tf
from tensorflow.keras import backend as K

class Attention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # one attention score per timestep: a weight over features plus a per-timestep bias
        self.W = self.add_weight(name='attention_weight', shape=(input_shape[-1], 1),
                                 initializer='random_normal', trainable=True)
        self.b = self.add_weight(name='attention_bias', shape=(input_shape[1], 1),
                                 initializer='zeros', trainable=True)
        super().build(input_shape)

    def call(self, x):
        e = K.tanh(K.dot(x, self.W) + self.b)
        e = K.squeeze(e, axis=-1)
        alpha = K.softmax(e)
        alpha = K.expand_dims(alpha, axis=-1)
        context = x * alpha              # weight each timestep by its attention score
        context = K.sum(context, axis=1)
        return context
My model -
from tensorflow.keras.layers import (Embedding, LSTM, Dense, Dropout,
                                     BatchNormalization, Input)
from tensorflow.keras.models import Model

class TweetClassificationModel(tf.keras.models.Model):
    def __init__(self, rnn_units, vocab_size, embedding_weights, emb_dim, input_len,
                 dropout=0.1, dense_units=100, **kwargs):
        super().__init__(**kwargs)
        self.emb = Embedding(vocab_size, emb_dim, weights=[embedding_weights])
        self.lstm = LSTM(rnn_units, return_sequences=True)
        self.dense = Dense(dense_units, activation='relu')
        self.dropout = Dropout(dropout)
        self.out_put = Dense(1, activation='sigmoid')
        self.input_len = input_len

    def call(self, x):
        x = self.emb(x)
        x = self.lstm(x)
        x = Attention()(x)
        x = self.dense(x)
        x = self.dropout(x)
        x = BatchNormalization()(x)
        return self.out_put(x)

    def summary(self):
        x = Input(shape=(self.input_len,))
        m = Model(x, self.call(x))
        return m.summary()
The error message referred me to this link, but it wasn't of much help.
CodePudding user response:
@xdurch0 is right. You shouldn't be creating layers inside the call() method. As the error says, a tf.Variable is being created more than once: because you instantiate the Attention and BatchNormalization layers inside call(), brand-new layers (and therefore new variables) are created on every forward pass, which tf.function does not allow. What you should be doing is this -
def build(self, input_shape):
    self.attn = Attention()
    self.bn = BatchNormalization()
    # ...and all the rest of the layers too, the same way they have been
    # defined in the constructor
    super().build(input_shape)
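call() should then reuse those attributes instead of constructing new layers each time it runs. A sketch of the corresponding call(), keeping the layer names from your model:
def call(self, x):
    x = self.emb(x)
    x = self.lstm(x)
    x = self.attn(x)   # built once in build(), only applied here
    x = self.dense(x)
    x = self.dropout(x)
    x = self.bn(x)     # same for batch normalization
    return self.out_put(x)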
You can also create layers inside the __init__() function, but it's recommended to use the build() function instead for this.
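With that change the model's variables are created exactly once, so a tf.function-wrapped training step like the one sketched in the question traces cleanly and reuses the same variables on later calls. A rough usage example (all sizes and data here are placeholders):
import numpy as np

vocab_size, emb_dim, input_len = 10000, 100, 50           # placeholder sizes
embedding_matrix = np.random.normal(size=(vocab_size, emb_dim))

model = TweetClassificationModel(rnn_units=64, vocab_size=vocab_size,
                                 embedding_weights=embedding_matrix,
                                 emb_dim=emb_dim, input_len=input_len)

x_batch = np.random.randint(0, vocab_size, size=(32, input_len))
y_batch = np.random.randint(0, 2, size=(32, 1)).astype('float32')
loss = train_step(x_batch, y_batch)   # no "singleton tf.Variables" error now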