ValueError: could not broadcast input array from shape (32,32,3) into shape (32,32)

Time:06-13

I am plotting a 15 x 15 grid of images after training my VAE model, but the code raises the error above. The code is as follows:

import numpy as np
import matplotlib.pyplot as plt

n = 15  # figure with 15x15 digits
digit_size = 32
figure = np.zeros((digit_size * n, digit_size * n))
# We will sample n points within [-15, 15] standard deviations
grid_x = np.linspace(-15, 15, n)
grid_y = np.linspace(-15, 15, n)

for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        x_decoded = decoder.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size, 3)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure)
plt.show()

I know that my model predicts images of size (32 * 32 * 3) from a 3072 latent space, but the slice of figure I assign into is only (32 * 32), which is why this error occurs. What I don't know is how to make the following part accept a (32 * 32 * 3) image:

figure[i * digit_size: (i + 1) * digit_size,
       j * digit_size: (j + 1) * digit_size] = digit
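A minimal reproduction of the error without the model, assuming the decoder output is a (32, 32, 3) RGB image; random values stand in for it here. The 2-D canvas slice has shape (32, 32), so NumPy cannot fit the extra channel axis into it:

```python
import numpy as np

digit_size = 32
n = 2  # small grid, just for the demonstration

# A 2-D canvas (height, width) with no channel axis, as in the question
figure = np.zeros((digit_size * n, digit_size * n))

# Stand-in for one decoded RGB image (the real decoder is not shown)
digit = np.random.rand(digit_size, digit_size, 3)

i, j = 0, 1
target = figure[i * digit_size: (i + 1) * digit_size,
                j * digit_size: (j + 1) * digit_size]
print(target.shape, digit.shape)  # (32, 32) (32, 32, 3)

error_message = None
try:
    figure[i * digit_size: (i + 1) * digit_size,
           j * digit_size: (j + 1) * digit_size] = digit
except ValueError as err:
    # Same kind of error as in the question: (32, 32, 3) does not fit into (32, 32)
    error_message = str(err)
    print(error_message)
```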

Thanks

Answer:

I can't test it, but as far as I can see, the whole problem is that you create figure with the wrong shape at the start.

It needs a third dimension of size 3 for the RGB channels:

figure = np.zeros((digit_size * n, digit_size * n, 3))
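A small check of this fix, again with random data standing in for the decoder output (an assumption, since the model is not shown). With the channel axis on figure, each slice has shape (32, 32, 3), which matches digit exactly, so the assignment succeeds:

```python
import numpy as np

digit_size = 32
n = 2

# Canvas with a trailing channel axis of size 3 (RGB)
figure = np.zeros((digit_size * n, digit_size * n, 3))
digit = np.random.rand(digit_size, digit_size, 3)  # stand-in for decoder output

i, j = 1, 0
# The slice is now (32, 32, 3), the same shape as digit, so no broadcasting is needed
figure[i * digit_size: (i + 1) * digit_size,
       j * digit_size: (j + 1) * digit_size] = digit

print(figure.shape)  # (64, 64, 3)
```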

Answer:

Try whether this works; I can't check it myself since you did not provide the complete code (with the VAE model and decoder).

import numpy as np
import matplotlib.pyplot as plt

n = 15  # figure with 15x15 digits
digit_size = 32
figure = np.zeros((digit_size * n, digit_size * n, 3))
# We will sample n points within [-15, 15] standard deviations
grid_x = np.linspace(-15, 15, n)
grid_y = np.linspace(-15, 15, n)

for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        x_decoded = decoder.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size, 3)
        # The slice now includes the channel axis, so it has shape (32, 32, 3)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size,
               :] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure)
plt.show()
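If the loop is slow, the same grid can be built with one batched prediction and a reshape. This is a sketch, not a tested drop-in: fake_decoder_predict is a hypothetical stand-in for decoder.predict (a fixed random projection followed by a sigmoid), and it assumes the real decoder accepts a batch of 2-D latent points and returns flattened 32 x 32 x 3 images, as a Keras model's predict does for a batch input.

```python
import numpy as np

n = 15
digit_size = 32
grid_x = np.linspace(-15, 15, n)
grid_y = np.linspace(-15, 15, n)

def fake_decoder_predict(z):
    """Hypothetical stand-in for decoder.predict: maps a batch of 2-D latent
    points to flattened 32x32x3 images with values in [0, 1]."""
    rng = np.random.default_rng(0)  # fixed seed, so every call uses the same weights
    w = rng.random((2, digit_size * digit_size * 3))
    return 1.0 / (1.0 + np.exp(-z @ w))  # sigmoid keeps values in [0, 1] for imshow

# All latent points at once: row i uses yi = grid_x[i], column j uses xi = grid_y[j]
yy, xx = np.meshgrid(grid_x, grid_y, indexing="ij")
z_batch = np.stack([xx.ravel(), yy.ravel()], axis=1)  # (n*n, 2), same [xi, yi] order as the loop

decoded = fake_decoder_predict(z_batch)                # (n*n, 3072)
tiles = decoded.reshape(n, n, digit_size, digit_size, 3)

# (row, col, y, x, c) -> (row, y, col, x, c), then merge row/y and col/x
figure = tiles.transpose(0, 2, 1, 3, 4).reshape(n * digit_size, n * digit_size, 3)
print(figure.shape)  # (480, 480, 3)
```

With the real model, replace fake_decoder_predict(z_batch) with decoder.predict(z_batch) and display the result with plt.imshow(figure) as before. The transpose interleaves the grid axes with the pixel axes, so tile (i, j) lands at the same rows and columns the loop writes it to.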