Model weights from sess.run() are returned as bytes. How can I get the actual values?

Time:05-03

I'm trying to extract the model weights from a saved model in a .pb file. However, when I run the session, it returns the model weights as bytes, which I cannot read. Here is my code:

import tensorflow as tf

constant_values = {}
with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    # loader.load already imports the saved graph into sess.graph and restores
    # the variables, so a separate tf.import_graph_def call is not needed.
    tf.compat.v1.saved_model.loader.load(
        sess, [tf.compat.v1.saved_model.tag_constants.SERVING], 'model_2/1/')
    # Keep only the operations whose type is "Const"
    constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
    for constant_op in constant_ops:
        # Evaluate the constant's output tensor
        value = sess.run(constant_op.outputs[0])
        constant_values[constant_op.name] = value
        print(constant_op.name, value)

Here is a piece of what it returns:

b'\n\x1b\n\t\x08\x01\x12\x05model\n\x0e\x08\x02\x12\nsignatures\n\xe2\x01\n\x18\x08\x03\x12\x14layer_with_weights-0\n\x0b\x08\x03\x12\x07layer-0\n\x0b\x08\x04\x12\x07layer-1\n\x18\x08\x05\x12\x14layer_with_weights-1\n\x0b\x08\x05\x12\x07layer-2\n\r\x08\x06\x12\tvariables\n\x17\x08\x07\x12\x13trainable_variables\n\x19\x08\x08\x12\x15regularization_losses\n\r\x08\t\x12\tkeras_api\n\x0e\x08\n\x12\nsignatures\n#\x08\x0b\x12\x1f_self_saveable_object_factories\n\x00\n\x92R\n\x0b\x08\x0c\x12\x07layer-0\n\x0b\x08\r\x12\x07layer-1\n\x18\x08\x0e\x12\x14layer_with_weights-0\n\x0b\x08\x0e\x12\x07layer-2\n\x0b\x08\x0f\x12\x07layer-3\n\x18\x08\x10\x12\x14layer_with_weights-1\n\x0b\x08\x10\x12\x07layer-4\n\x18\x08\x11\x12\x14layer_with_weights-2...

Thanks

Answer:

Are you sure that the ops holding your model weights are of type "Const"?
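One way to check what the Const ops actually contain is to look at the dtype of each op's output: sess.run returns a Python bytes object for a scalar tf.string tensor, so string constants print exactly like the b'...' output above. The following is a minimal sketch; describe_constants and string_constants are hypothetical helper names, and ops is assumed to be sess.graph.get_operations() from the question's session.

```python
def describe_constants(ops):
    """Return (op name, output dtype name) for every op of type "Const".

    `ops` is expected to be sess.graph.get_operations(); any objects that
    expose .type, .name and .outputs[0].dtype.name behave the same way.
    """
    return [
        (op.name, op.outputs[0].dtype.name)
        for op in ops
        if op.type == "Const"
    ]


def string_constants(ops):
    """Names of Const ops whose output is a tf.string tensor.

    sess.run returns such scalar tensors as raw Python bytes, not numbers.
    """
    return [name for name, dtype in describe_constants(ops) if dtype == "string"]
```

Called as string_constants(sess.graph.get_operations()) inside the question's with block, it would list the string constants. The value shown in the question appears to be a serialized protocol buffer (note strings such as layer_with_weights-0 and signatures), which suggests it describes the model's structure rather than holding the weights themselves.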

If you copied this code from a tutorial about extracting model weights, as I have seen people do in the past, try the following:

Instead of constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"], try constant_ops = [op for op in sess.graph.get_operations()] and inspect the names and types of all the operations in the graph. You will probably find that the weight nodes have a different type.
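The diagnostic step above can be sketched as follows, assuming the same SavedModel directory and SERVING tag as in the question. group_ops_by_type and print_saved_model_ops are hypothetical helpers; TensorFlow is imported inside print_saved_model_ops only so that the grouping logic can be exercised without it. If the model was exported with TensorFlow 2 / Keras, the weights may show up as variable ops (for example VarHandleOp) rather than Const, but that is an assumption to confirm against the printed listing.

```python
from collections import defaultdict


def group_ops_by_type(ops):
    """Map each op type (e.g. "Const", "VarHandleOp") to the names of ops of that type."""
    by_type = defaultdict(list)
    for op in ops:
        by_type[op.type].append(op.name)
    return dict(by_type)


def print_saved_model_ops(export_dir):
    """Load a SavedModel in a TF1-style session and print every op, grouped by type."""
    # Imported here so that group_ops_by_type can be used without TensorFlow installed
    import tensorflow as tf

    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
        tf.compat.v1.saved_model.loader.load(
            sess, [tf.compat.v1.saved_model.tag_constants.SERVING], export_dir
        )
        grouped = group_ops_by_type(sess.graph.get_operations())
        for op_type, names in sorted(grouped.items()):
            print(f"{op_type} ({len(names)}):")
            for name in names:
                print("   ", name)
```

For the model in the question, this would be called as print_saved_model_ops('model_2/1/'). Grouping by type keeps the listing readable even when the graph contains hundreds of ops.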

Best,
