I'm trying to extract the model weights from a SavedModel stored in a .pb file. However, when I run the session, it returns the weights as bytes that I cannot read. Here is my code:
```python
import tensorflow as tf

constant_values = {}
with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    # loader.load() already imports the MetaGraph into sess.graph and restores the variables
    tf.compat.v1.saved_model.loader.load(
        sess, [tf.compat.v1.saved_model.tag_constants.SERVING], 'model_2/1/')
    constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
    for constant_op in constant_ops:
        value = sess.run(constant_op.outputs[0])
        constant_values[constant_op.name] = value
        print(constant_op.name, value)
```
Here is a piece of what it returns:
b'\n\x1b\n\t\x08\x01\x12\x05model\n\x0e\x08\x02\x12\nsignatures\n\xe2\x01\n\x18\x08\x03\x12\x14layer_with_weights-0\n\x0b\x08\x03\x12\x07layer-0\n\x0b\x08\x04\x12\x07layer-1\n\x18\x08\x05\x12\x14layer_with_weights-1\n\x0b\x08\x05\x12\x07layer-2\n\r\x08\x06\x12\tvariables\n\x17\x08\x07\x12\x13trainable_variables\n\x19\x08\x08\x12\x15regularization_losses\n\r\x08\t\x12\tkeras_api\n\x0e\x08\n\x12\nsignatures\n#\x08\x0b\x12\x1f_self_saveable_object_factories\n\x00\n\x92R\n\x0b\x08\x0c\x12\x07layer-0\n\x0b\x08\r\x12\x07layer-1\n\x18\x08\x0e\x12\x14layer_with_weights-0\n\x0b\x08\x0e\x12\x07layer-2\n\x0b\x08\x0f\x12\x07layer-3\n\x18\x08\x10\x12\x14layer_with_weights-1\n\x0b\x08\x10\x12\x07layer-4\n\x18\x08\x11\x12\x14layer_with_weights-2...
Thanks
Answer:
Are you sure that the model weights in your graph are stored in ops of type `Const`? (Note that `op.type == "Const"` filters on the op type, not on the op name.)
If you copied this code from a tutorial about extracting weights from a different model, as I have seen people do before, try the following:
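Instead of `constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]`, try `constant_ops = [op for op in sess.graph.get_operations()]` and look at the names and types of all the operations in the graph. You will probably discover that the weight nodes are of a different type: in a SavedModel, weights are usually stored as variables (for example `VarHandleOp` or `VariableV2` ops) that are restored from the `variables/` directory, not as `Const` ops.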
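A minimal sketch of that inspection, assuming TensorFlow 2.x with the `tf.compat.v1` API available. It counts the op types in the loaded graph and then reads every variable op it finds. The function name `inspect_saved_model` is hypothetical, and the op types it looks for (`VarHandleOp` for resource variables, `VariableV2` for older reference variables) are an assumption about how your model stores its weights, so check them against the printed counts. The bytes blob in the question (with entries like `layer_with_weights-0`) looks like a serialized description of the model's object graph rather than a weight tensor, which would be consistent with this.

```python
import collections

import tensorflow as tf


def inspect_saved_model(export_dir):
    """Hypothetical helper: load a SavedModel in a TF1-style session and
    return (op-type counts, {variable op name: numpy value})."""
    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
        # load() imports the MetaGraph into `graph` and restores the variables.
        tf.compat.v1.saved_model.loader.load(
            sess, [tf.compat.v1.saved_model.tag_constants.SERVING], export_dir)
        ops = graph.get_operations()  # snapshot before we add read ops
        op_type_counts = collections.Counter(op.type for op in ops)

        fetches = {}
        for op in ops:
            if op.type == "VarHandleOp":
                # Assumption: resource variable; its handle must be read explicitly.
                fetches[op.name] = tf.raw_ops.ReadVariableOp(
                    resource=op.outputs[0], dtype=op.get_attr("dtype"))
            elif op.type == "VariableV2":
                # Assumption: legacy reference variable; its output can be fetched directly.
                fetches[op.name] = op.outputs[0]
        # sess.run accepts a dict of fetches and returns a dict of numpy values.
        variable_values = sess.run(fetches) if fetches else {}
    return op_type_counts, variable_values
```

For your model, something like `counts, weights = inspect_saved_model('model_2/1/')` followed by `print(counts.most_common(10))` and a loop over `weights.items()` printing each name and `value.shape` should show which op types actually hold the weights.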
Best,