TFlite: set_tensor() takes 3 positional arguments but 4 were given

Time:05-26

I've written a simple program to calculate a quadratic equation with TensorFlow. Now I'd like to convert the code so it runs on the Coral Dev Board using TensorFlow Lite.

The following code generates the .tflite file:

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Define and compile the neural network
model = tf.keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')

# Provide the data
xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)

# Generate the TFLite model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the TFLite-Model
with open('mobilenet_v2_1.0_224.tflite', 'wb') as f:
    f.write(tflite_model)

This code runs on the Coral Dev Board:

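import numpy as np
import tflite_runtime.interpreter as tflite  # assumed import; the original post does not show it

# Load TFLite model and allocate tensors.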
interpreter = tflite.Interpreter(model_path="mobilenet_v2_1.0_224.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test model on random input data.
xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=np.float32)
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], xs, ys)
...

The last line raises an error:

TypeError: set_tensor() takes 3 positional arguments but 4 were given

The output of input_details[0]:

{'name': 'serving_default_dense_input:0',
 'index': 0,
 'shape': array([1, 1], dtype=int32),
 'shape_signature': array([-1,  1], dtype=int32),
 'dtype': <class 'numpy.float32'>,
 'quantization': (0.0, 0),
 'quantization_parameters':
     {'scales': array([], dtype=float32),
     'zero_points': array([], dtype=int32),
     'quantized_dimension': 0},
 'sparsity_parameters': {}
}

I don't understand the cause of the error. Does anyone have an idea?

CodePudding user response:

The error comes from the number of arguments, not from their types. set_tensor() accepts exactly two arguments, the tensor index and the value. Python also counts self, which is why the message says the method takes 3 positional arguments. Your call passes the index, xs and ys, so with self that makes 4, and Python raises the TypeError.

Now to fix your code. The first argument is the index of the input tensor. As the data returned by interpreter.get_input_details() shows, that index is 0, and input_details[0]['index'] evaluates to exactly that, so this part of your call is fine. The second argument is the input data only. The model has a single input, so pass the x values and drop ys. Do not pass both at the same time. The value also has to match the input shape [1, 1] and dtype float32 shown in your output, so feed one x value per call. So rewrite the line like this:

interpreter.set_tensor(input_details[0]['index'], xs[:1].reshape(1, 1))

I hope this helps. It is usually a good idea to check the documentation as well, so you understand what each method expects: https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter#set_tensor
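
The argument count in the error message can be reproduced without TensorFlow. The sketch below uses a hypothetical stand-in class, FakeInterpreter, whose set_tensor() has the same (tensor_index, value) parameter shape as the documented method. It is not the real interpreter and only illustrates how Python counts self.

```python
class FakeInterpreter:
    """Hypothetical stand-in; only mimics the (tensor_index, value) signature."""

    def __init__(self):
        self.tensors = {}

    def set_tensor(self, tensor_index, value):
        # self + tensor_index + value = 3 positional arguments
        self.tensors[tensor_index] = value


interpreter = FakeInterpreter()
xs = [-1.0, 0.0, 1.0]
ys = [-3.0, -1.0, 1.0]

error_message = ""
try:
    # self + index + xs + ys = 4 positional arguments
    interpreter.set_tensor(0, xs, ys)
except TypeError as err:
    error_message = str(err)
    print(error_message)

# Passing the index and a single value matches the signature.
interpreter.set_tensor(0, xs)
```

The same counting applies to the real call: one index, one value, nothing else.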

CodePudding user response:

My approach was wrong. xs holds the x values and ys holds the y values (the results) of the quadratic equation. I was not aware that you cannot do training in TFLite. But thanks for the effort anyway.
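
For anyone landing here with the same confusion, the usual split is to train in Keras and use the TFLite interpreter for inference only. Below is a minimal sketch of that workflow using the model and data from the question. The epochs=500 value, loading the model from memory via model_content, and the one-sample-per-call loop are assumptions for illustration. On the Dev Board the interpreter would presumably come from tflite_runtime instead of tf.lite.

```python
import numpy as np
import tensorflow as tf

# Data from the question.
xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=np.float32)
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=np.float32)

# Train in Keras first (epochs=500 is an assumed value).
model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(xs.reshape(-1, 1), ys.reshape(-1, 1), epochs=500, verbose=0)

# Convert the trained model and load it straight from memory.
tflite_model = tf.lite.TFLiteConverter.from_keras_model(model).convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_index = interpreter.get_input_details()[0]['index']
output_index = interpreter.get_output_details()[0]['index']

# Inference only: feed one x value per call, shaped [1, 1] like the input tensor.
predictions = []
for x in xs:
    interpreter.set_tensor(input_index, np.array([[x]], dtype=np.float32))
    interpreter.invoke()
    predictions.append(float(interpreter.get_tensor(output_index)[0][0]))

# ys is only used for comparison; it is never passed to set_tensor().
for x, y, p in zip(xs, ys, predictions):
    print(f"x={x:.1f} expected={y:.1f} predicted={p:.3f}")
```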
