I have a custom loss function that is reporting an error before any real processing happens.
I have a y_train of shape (2717, 5, 5, 6), a batch size of 25, and constants S1 = S2 = 5. All I do is call tf.reshape to make sure I get the desired shape of (25, 5, 5, 6). Then I want to extract one axis, but it's somehow not working properly.
```python
import tensorflow as tf

@tf.function
def yolo_loss(y_true, y_pred):
    #mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.SUM)
    lambda_noobj = 0.5
    lambda_coord = 5
    y_pred = tf.reshape(y_pred, [batch_size, S1, S2, C + B*5])
    y_true = tf.reshape(y_true, [batch_size, S1, S2, 6])
    exists_box = tf.reshape(y_true[..., 0], [batch_size, S1, S2, 1])
    # ........
```
While the first reshape of y_true works fine, I get an error on the exists_box line. To be precise:
```
exists_box = tf.reshape(y_true[...,0],[batch_size,S1,S2,1])
Node: 'Reshape_2'
Input to reshape is a tensor with 425 values, but the requested shape has 625
[[{{node Reshape_2}}]] [Op:__inference_train_function_44379]
```
The ellipsis in [..., 0] should give me a tensor with 25 * 5 * 5 = 625 values, so I am confused why the error says it has 425. I also made sure that all arrays in y_train have the same shape.
Answer:
It seems that the error is caused by the last batch of your y_train dataset, which has shape (17, 5, 5, 6), so y_true[..., 0] holds only 17 * 5 * 5 * 1 = 425 values. This happens because when TensorFlow batches your data, the last batch contains all the remaining elements, and their number does not have to equal your specified batch_size (in your case 25). Note that 2717 % 25 = 17.
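The arithmetic can be checked with a small plain-Python model of batching. The helper `batch_sizes` below is hypothetical: it only mimics how `Dataset.batch` splits examples (full batches first, then one partial batch unless `drop_remainder=True`) and is not TensorFlow's implementation.

```python
def batch_sizes(num_examples, batch_size, drop_remainder=False):
    """Hypothetical model of how examples are split into batches."""
    # Full batches of batch_size, then whatever is left over as a final batch.
    sizes = [min(batch_size, num_examples - start)
             for start in range(0, num_examples, batch_size)]
    # Optionally discard the final batch if it is smaller than batch_size.
    if drop_remainder and sizes and sizes[-1] < batch_size:
        sizes = sizes[:-1]
    return sizes

sizes = batch_sizes(2717, 25)
print(len(sizes), sizes[-1])   # 109 17
print(sizes[-1] * 5 * 5)       # 425 values in y_true[..., 0] for the last batch
print(25 * 5 * 5)              # 625 values the reshape asks for
print(len(batch_sizes(2717, 25, drop_remainder=True)))  # 108
```

So 108 batches match the hard-coded shape, and the 109th batch of 17 examples is the one that triggers the error.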
There are two things you can do:
- Drop the remaining elements from the dataset. Use this option if you are okay with losing a few examples from your training data. If you are using a `tf.data.Dataset` object, this can be done by passing `drop_remainder=True` to the `batch` method:

  ```python
  dataset = dataset.batch(25, drop_remainder=True)
  ```
- Change your loss function so that it can handle inputs whose first dimension is not 25. From your description it's not clear what the rest of your loss function does, so you'll have to work out the details yourself.
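For the second option, here is a minimal sketch, assuming the setup from the question (S1 = S2 = 5 and six target channels, taken here to mean C = 1 class and B = 1 box). It passes -1 as the leading dimension so `tf.reshape` infers the batch size, and reads the batch size with `tf.shape(...)[0]` when it is needed as a value. The helper names `exists_box_mask` and `dynamic_batch_size` are hypothetical.

```python
import tensorflow as tf

S1, S2 = 5, 5  # grid size from the question
C, B = 1, 1    # assumption: 1 class and 1 box, so C + B*5 = 6 channels

@tf.function
def exists_box_mask(y_true):
    # -1 lets tf.reshape infer the batch dimension from the input, so a full
    # batch of 25 and the final partial batch of 17 are both accepted.
    y_true = tf.reshape(y_true, [-1, S1, S2, C + B * 5])
    # Slicing with 0:1 keeps a trailing axis of size 1: shape (batch, S1, S2, 1).
    return y_true[..., 0:1]

@tf.function
def dynamic_batch_size(y_true):
    # Read the batch size from the tensor at run time instead of a constant.
    return tf.shape(y_true)[0]

full_batch = tf.zeros([25, S1, S2, 6])
last_batch = tf.zeros([17, S1, S2, 6])
print(tuple(exists_box_mask(full_batch).shape))  # (25, 5, 5, 1)
print(tuple(exists_box_mask(last_batch).shape))  # (17, 5, 5, 1)
print(int(dynamic_batch_size(last_batch)))       # 17
```

The same -1 works for the y_pred reshape. If the loss also divides by the batch size, using the run-time value instead of the constant 25 keeps the last batch weighted consistently.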