I'm trying to build the networks presented in the following paper:
The following error is raised:
var_list = self.embedder.trainable_variables + self.recovery.trainable_variables
AttributeError: 'function' object has no attribute 'trainable_variables'
What am I doing wrong?
The baseline code that I'm reproducing can be found here
Answer:
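The problem is that `embedder` and `recovery` are plain functions, not models with `trainable_variables`. Those two functions simply return the output tensor of their last layer, so there is no object that owns the weights. Instead, have each function return a `tf.keras.Sequential` model and call that model on its input. For example: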
import tensorflow as tf
# Symbolic Keras input: a fixed batch of 2 samples, each of shape (10, 10).
X = tf.keras.layers.Input(
    shape=(10, 10),
    batch_size=2,
    name='RealData',
)
def recovery(units=10, num_layers=3, output_dim=10):
    """Build the recovery network as a ``tf.keras.Sequential`` model.

    Returning a model object (instead of the tensor produced by calling the
    layers) is what gives callers access to ``trainable_variables``.

    Args:
        units: Number of units in each stacked LSTM layer.
        num_layers: How many ``return_sequences=True`` LSTM layers to stack.
        output_dim: Width of the final sigmoid ``Dense`` layer named ``'OUTPUT'``.

    Returns:
        A ``tf.keras.Sequential`` model. It has no input shape, so its
        variables are created the first time it is called on a tensor.
        The defaults reproduce the original 3 x LSTM(10) + Dense(10) stack.
    """
    stack = [
        tf.keras.layers.LSTM(units, return_sequences=True)
        for _ in range(num_layers)
    ]
    # Sigmoid bounds outputs to (0, 1); presumably the data is scaled to that
    # range. Confirm against the preprocessing.
    stack.append(
        tf.keras.layers.Dense(output_dim, activation='sigmoid', name='OUTPUT')
    )
    return tf.keras.Sequential(stack)
def embedder(units=10, num_layers=3, output_dim=10):
    """Build the embedder network as a ``tf.keras.Sequential`` model.

    Returning a model object (instead of the tensor produced by calling the
    layers) is what gives callers access to ``trainable_variables``.

    Args:
        units: Number of units in each stacked LSTM layer.
        num_layers: How many ``return_sequences=True`` LSTM layers to stack.
        output_dim: Width of the final sigmoid ``Dense`` layer.

    Returns:
        A ``tf.keras.Sequential`` model. It has no input shape, so its
        variables are created the first time it is called on a tensor.
        The defaults reproduce the original 3 x LSTM(10) + Dense(10) stack.
    """
    stack = [
        tf.keras.layers.LSTM(units, return_sequences=True)
        for _ in range(num_layers)
    ]
    stack.append(tf.keras.layers.Dense(output_dim, activation='sigmoid'))
    return tf.keras.Sequential(stack)
# Instantiate each sub-network once and keep the model objects: the models,
# not the tensors they output, own the weights.
embedder_model = embedder()
H = embedder_model(X)  # first call creates the embedder's variables
recovery_model = recovery()
X_tilde = recovery_model(H)  # first call creates the recovery's variables
autoencoder = tf.keras.Model(inputs=X, outputs=X_tilde)

# Collect the variables of BOTH sub-networks (embedder first, then recovery).
# The original line was missing the `+` (a syntax error) and listed the
# embedder's variables twice while omitting the recovery's.
var_list = embedder_model.trainable_variables + recovery_model.trainable_variables
tf.print(var_list[:2])
[[[0.343916416 0.310338378 0.34440577 ... 0.0633761585 0.0405358076 0.276733816]
[0.245998859 0.197870493 0.0333348215 ... -0.136249736 0.271893084 -0.0605607331]
[-0.290359527 0.240957797 0.117871583 ... 0.172593892 0.113803834 0.0506341457]
...
[0.15672195 -0.161336392 -0.13484776 ... 0.306486845 -0.0707859397 0.245753765]
[0.00567743182 0.181330919 0.206510961 ... 0.0141542256 0.205756843 -0.074064374]
[0.299010575 -0.236641362 0.272176802 ... 0.0658480823 0.04648754 -0.342863292]], [[0.224076748 -0.112819761 -0.114276126 ... -0.190908 -0.282466382 -0.0711786151]
[-0.0689174235 0.203702673 -0.248280779 ... -0.0145524191 0.202952 0.0797807127]
[0.0919017 0.108805738 -0.124872617 ... 0.26839748 0.21041657 0.251440644]
...
[-0.117122218 -0.0974424109 -0.17138055 ... 0.150875479 0.0454813093 0.0753096]
[-0.115990438 -0.360190183 -0.0988362879 ... -0.0655761734 0.11425022 0.0291871373]
[-0.00164104556 -0.0442082509 0.135109842 ... -0.182655513 -0.0121813752 0.0497299805]]]