I am plotting 15 images after training my VAE model, but it generates the above error. The code is as follows:
n = 15  # figure with 15x15 digits
digit_size = 32
figure = np.zeros((digit_size * n, digit_size * n))
# We will sample n points within [-15, 15] standard deviations
grid_x = np.linspace(-15, 15, n)
grid_y = np.linspace(-15, 15, n)
for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        x_decoded = decoder.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size, 3)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit
plt.figure(figsize=(10, 10))
plt.imshow(figure)
plt.show()
I know that my model predicts images of size (32 * 32 * 3) = 3072 values from the latent space, but the array I am assigning into is only (32 * 32), which is why it generates this error. However, I don't know how to make the part below accept a (32 * 32 * 3) image:
figure[i * digit_size: (i + 1) * digit_size,
       j * digit_size: (j + 1) * digit_size] = digit
Thanks
CodePudding user response:
I can't test it, but as far as I can see the whole problem is that you create the wrong figure array at the start. It needs a third dimension of size 3:
figure = np.zeros((digit_size * n, digit_size * n, 3))
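A minimal sketch showing why the channel axis matters, using a random array as a stand-in for a decoded image (the real decoder isn't shown in the question): assigning a (32, 32, 3) tile into a 2D slice raises a ValueError, while the 3D canvas accepts it.

```python
import numpy as np

digit_size = 32
digit = np.random.rand(digit_size, digit_size, 3)  # stand-in for one decoded image

# 2D canvas: the assignment fails because the tile carries a colour channel
figure_2d = np.zeros((digit_size, digit_size))
try:
    figure_2d[0:digit_size, 0:digit_size] = digit
except ValueError as e:
    print("2D canvas fails:", e)

# 3D canvas with a channel axis: the same assignment succeeds
figure_3d = np.zeros((digit_size, digit_size, 3))
figure_3d[0:digit_size, 0:digit_size] = digit
print("3D canvas shape:", figure_3d.shape)  # (32, 32, 3)
```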
CodePudding user response:
Try if this works; I can't check it myself since you did not provide complete code (with the VAE model generator):
n = 15  # figure with 15x15 digits
digit_size = 32
figure = np.zeros((digit_size * n, digit_size * n, 3))
# We will sample n points within [-15, 15] standard deviations
grid_x = np.linspace(-15, 15, n)
grid_y = np.linspace(-15, 15, n)
for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        x_decoded = decoder.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size, 3)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size,
               :] = digit
plt.figure(figsize=(10, 10))
plt.imshow(figure)
plt.show()
</plt>
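As a sanity check, the grid-assembly loop runs end-to-end with a stand-in decoder (`fake_decoder_predict` is hypothetical, since the real Keras model isn't shown; it just returns random values with the right flat shape). The output canvas is noise, but the shapes confirm the tiling works.

```python
import numpy as np

n = 15
digit_size = 32

# Hypothetical stand-in for decoder.predict: returns a batch of flat
# 32*32*3 outputs, one per latent point, as the real decoder would
def fake_decoder_predict(z):
    rng = np.random.default_rng(0)
    return rng.random((z.shape[0], digit_size * digit_size * 3))

figure = np.zeros((digit_size * n, digit_size * n, 3))
grid_x = np.linspace(-15, 15, n)
grid_y = np.linspace(-15, 15, n)

for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        x_decoded = fake_decoder_predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size, 3)
        figure[i * digit_size:(i + 1) * digit_size,
               j * digit_size:(j + 1) * digit_size, :] = digit

print(figure.shape)  # (480, 480, 3)
```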