I have a TensorFlow model that I have loaded from a repository as
model = tf.saved_model.load(folder)
My objective is to replicate this model in JAX, and to do so I need to verify that the loaded variable values (weights and biases) are the correct ones.
One way to recover the value of variable i is simply
vars = model.variables
print(vars[i].numpy())
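Before copying anything into JAX, it can help to dump each variable's name, shape, and dtype so the TF variables can be matched to the JAX parameters unambiguously. A minimal sketch (the names and shapes below are hypothetical; in practice the pairs would come from `(v.name, v.numpy())` over `model.variables`):

```python
import numpy as np

def summarize_variables(named_arrays):
    """Return one 'name: shape, dtype' line per (name, array) pair."""
    return [f"{name}: shape={a.shape}, dtype={a.dtype}" for name, a in named_arrays]

# Hypothetical stand-ins for (v.name, v.numpy()) pairs from model.variables.
named = [
    ("dense/kernel:0", np.zeros((11, 256), dtype=np.float32)),
    ("dense/bias:0", np.zeros((256,), dtype=np.float32)),
]
for line in summarize_variables(named):
    print(line)
```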
If I assign these values to the JAX network, however, I do not recover the right results. To debug this, I am trying to analyze the output of specific layers, and for that I need to make sure the weights and biases are the same in both frameworks, e.g. by assigning them beforehand. Specifically, if I do
numpy_vars = [v.numpy() for v in vars]  # This is done in eager mode.
with tf.compat.v1.Session(graph=graph) as sess:
    tvars = tf.compat.v1.trainable_variables()
    tf.compat.v1.variables_initializer(vars).run()  # Necessary init. of either tvars/vars
    for v, tv in zip(numpy_vars, tvars):
        tv.assign(v)
    print(tvars[0].eval())  # This returns the value of the variable in graph mode.
    print('------------------------------')
    print(numpy_vars[0])
The two printouts do not match, even though both arrays have the same shape and I expected them to be equal. I wonder whether this might be because there are initialization operations in model.graph, but I am not quite sure. If I instead change the line
tv.assign(v)
with
sess.run(tv.assign(v))
I get error
TypeError: Argument `fetch` = <tf.Variable 'UnreadVariable' shape=(11, 256) dtype=float32> has invalid type "_UnreadVariable" must be a string or Tensor. (Can not convert a _UnreadVariable into a Tensor or Operation.)
Any suggestions on how to assign the values of these variables so that they remain fixed during graph execution?
CodePudding user response:
The answer seems to be this:
numpy_vars = [v.numpy() for v in vars]
with tf.compat.v1.Session(graph=graph) as sess:
    tvars = tf.compat.v1.trainable_variables()
    tf.compat.v1.variables_initializer(vars).run()
    print(tvars[0].eval())
    print('------------------------------')
    for v, tv in zip(numpy_vars, tvars):
        tf.compat.v1.assign(tv, v).read_value().eval()
    print(tvars[0].eval())
    print('------------------------------')
    print(numpy_vars[0])
After evaluating the line
tf.compat.v1.assign(tv, v).read_value().eval()
the graph-mode variables hold the same values as the eager-mode numpy arrays, and I have checked that the weights and biases behave appropriately.
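Rather than comparing printouts by eye, the round-trip can be verified numerically. A small helper along these lines would work, assuming both sides are available as lists of numpy arrays (e.g. `tvars[i].eval()` on the graph side and `numpy_vars[i]` on the eager side); the arrays below are hypothetical examples:

```python
import numpy as np

def params_match(arrays_a, arrays_b, atol=0.0):
    """True iff the two lists contain pairwise equal arrays (within atol)."""
    if len(arrays_a) != len(arrays_b):
        return False
    return all(
        a.shape == b.shape and np.allclose(a, b, atol=atol, rtol=0.0)
        for a, b in zip(arrays_a, arrays_b)
    )

# Hypothetical example values standing in for eval()'d and eager arrays.
graph_side = [np.arange(6, dtype=np.float32).reshape(2, 3)]
eager_side = [np.arange(6, dtype=np.float32).reshape(2, 3)]
print(params_match(graph_side, eager_side))  # → True
```

With `atol=0.0` this demands exact equality, which is appropriate here since the values are copied rather than recomputed.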