I am trying to define a custom DenseNet, but I am getting a strange error and do not understand why. The code is as follows:
def densenet(input_shape, n_classes, filters=32):
    """Build a DenseNet-style Keras classification model.

    Layout: a 7x7 stride-2 convolution with 64 filters, a 3x3 stride-2
    max pool, four (dense block -> transition layer) stages with 2, 4, 6
    and 4 repetitions, global average pooling, and a softmax ``Dense``
    layer.

    Args:
        input_shape: Shape handed to ``Input``.
        n_classes: Number of units in the final softmax layer.
        filters: Number of filters produced by the 3x3 convolution in
            each dense-block repetition; the preceding 1x1 convolution
            uses ``4 * filters``.

    Returns:
        A ``Model`` mapping the input tensor to the softmax output.
    """

    def bn_rl_conv(x, filters, kernel=1, strides=1):
        """Apply BatchNormalization, then ReLU, then a 'same'-padded Conv2D."""
        x = BatchNormalization()(x)
        x = ReLU()(x)
        x = Conv2D(filters, kernel, strides=strides, padding='same')(x)
        return x

    def dense_block(x, repetition):
        """Run ``repetition`` rounds of 1x1 conv -> 3x3 conv -> concatenate.

        Each round concatenates its new feature maps ``y`` with the
        running tensor ``x``; ``filters`` comes from the enclosing
        ``densenet`` call.
        """
        for _ in range(repetition):
            y = bn_rl_conv(x, 4 * filters)
            y = bn_rl_conv(y, filters, 3)
            x = concatenate([y, x])
        return x

    def transition_layer(x):
        """Halve the last-axis size with a 1x1 conv, then 2x2 average-pool."""
        x = bn_rl_conv(x, K.int_shape(x)[-1] // 2)
        x = AvgPool2D(2, strides=2, padding='same')(x)
        return x

    # The shape is passed positionally here, exactly as in the original post.
    inp = Input(input_shape)
    x = Conv2D(64, 7, strides=2, padding='same')(inp)
    x = MaxPool2D(3, strides=2, padding='same')(x)

    for repetition in [2, 4, 6, 4]:
        d = dense_block(x, repetition)
        x = transition_layer(d)

    # Pooling is applied to the output of the last transition layer.
    x = GlobalAveragePooling2D()(x)
    output = Dense(n_classes, activation='softmax')(x)

    model = Model(inp, output)
    return model
# Build the model for a (1024, 2, 1) input and 24 output classes.
input_shape = (1024, 2, 1)
num_classes = 24
model = densenet(input_shape=input_shape, n_classes=num_classes)
The error message is as follows:
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
/usr/local/lib/python3.7/dist-packages/keras/layers/normalization/batch_normalization.py in build(self, input_shape)
296 if not input_shape.ndims:
297 raise ValueError(
--> 298 f'Input has undefined rank. Received: input_shape={input_shape}.')
299 ndims = len(input_shape)
300
ValueError: Input has undefined rank. Received: input_shape=<unknown>.
Why am I getting this error? I have already indicated the input shape. How can I fix this issue?
CodePudding user response:
You are calling the `Input` layer incorrectly. You're passing `input_shape` to the `__call__()` method instead of the `shape` parameter.
Change:
inp = Input (input_shape)
To:
inp = Input(shape=input_shape)