I have a plot with multiple images arranged in 20 rows and 3 columns. In each row, the 1st column should be the original image, the 2nd column the mask, and the 3rd column the segmentation obtained from a neural network.
I tried the code below. The first 2 rows are plotted correctly, but from the 3rd row onward the images overlap the 1st row.
import matplotlib.pyplot as plt
import numpy as np

f = plt.figure()
for j in range(3):
    pred_y = y_pred[j]
    pred_y = pred_y.reshape(256, 256)
    pred_y = (pred_y > 0.1).astype(np.uint8)
    f.add_subplot(j + 1, 3, 1)
    plt.imshow(X_test[j, :, :])
    f.add_subplot(j + 1, 3, 2)
    plt.imshow(y_test[j])
    f.add_subplot(j + 1, 3, 3)
    plt.imshow(pred_y)
plt.show()
This is the output I obtained:
Answer:
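The arguments to Figure.add_subplot are nrows, ncols, index, not row, nrows, col. For j = 0, 1, 2 you are therefore placing each set of three plots in the top row of a 1x3, a 2x3 and a 3x3 grid respectively, so they all stack up at the top of the figure. The index starts at 1 in the upper-left cell and increases across each row, so make your calls to add_subplot like this: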
f.add_subplot(20, 3, 3 * j + 1)
f.add_subplot(20, 3, 3 * j + 2)
f.add_subplot(20, 3, 3 * j + 3)
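To see why the original calls pile up, here is a small pure-Python sketch of how the 1-based, row-major subplot index maps to a grid cell. The helpers `subplot_position` and `subplot_index` are hypothetical, written only for illustration; they are not part of matplotlib, and they only model the numbering rule, not the cell sizes.

```python
def subplot_position(nrows, ncols, index):
    """Return the 0-based (row, col) cell that add_subplot(nrows, ncols, index) fills.

    Hypothetical helper: subplot indices start at 1 in the upper-left cell and
    increase across each row before moving down (row-major order).
    """
    if not 1 <= index <= nrows * ncols:
        raise ValueError("index must be in 1..nrows*ncols")
    return divmod(index - 1, ncols)


def subplot_index(row, col, ncols=3):
    """Hypothetical inverse: 1-based index for a 0-based (row, col) cell."""
    return row * ncols + col + 1


# The question's calls: a (j+1) x 3 grid with index 1..3 every time.
# Each call lands in row 0 of its own grid, i.e. at the top of the figure.
for j in range(3):
    print([subplot_position(j + 1, 3, k) for k in (1, 2, 3)])

# The corrected calls: a fixed 20 x 3 grid with index 3*j + 1 .. 3*j + 3,
# so sample j gets its own row.
for j in range(3):
    print([subplot_position(20, 3, subplot_index(j, c)) for c in range(3)])
```

Because the three grids in the original loop have different row heights, their "row 0" cells are different sizes but all start at the top edge, which is why later rows are drawn over the first one.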
A better way to generate your axes might be plt.subplots, as suggested in the demo:
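import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(20, 3, constrained_layout=True)
for i in range(20):
    # one row per sample: image, ground-truth mask, thresholded prediction
    ax[i, 0].imshow(X_test[i])
    ax[i, 1].imshow(y_test[i])
    ax[i, 2].imshow((y_pred[i].reshape(256, 256) > 0.1).astype(np.uint8))
plt.show()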
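A fuller sketch of the plt.subplots approach, assuming X_test, y_test and y_pred are NumPy arrays as in the question and that each prediction holds 256 * 256 values. The function name `plot_segmentation_grid`, the figure size, the column titles and the Agg backend are my own choices, and the random arrays in the `__main__` block are hypothetical stand-ins for real data.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get an interactive window
import matplotlib.pyplot as plt
import numpy as np


def plot_segmentation_grid(images, masks, preds, threshold=0.1, size=256):
    """One row per sample: original image, ground-truth mask, thresholded prediction.

    Hypothetical helper. Assumes each entry of `preds` has size * size elements.
    """
    n = len(images)
    # squeeze=False keeps `axes` 2-D even when n == 1, so axes[i, c] always works;
    # roughly 3 inches per row keeps 20 rows legible
    fig, axes = plt.subplots(n, 3, figsize=(9, 3 * n), squeeze=False,
                             constrained_layout=True)
    titles = ("Image", "Mask", "Prediction")
    for i in range(n):
        pred = (np.asarray(preds[i]).reshape(size, size) > threshold).astype(np.uint8)
        for c, data in enumerate((images[i], masks[i], pred)):
            axes[i, c].imshow(data)
            axes[i, c].axis("off")  # ticks add clutter in a tall grid
            if i == 0:
                axes[i, c].set_title(titles[c])  # label the columns once, on top
    return fig, axes


if __name__ == "__main__":
    # Synthetic stand-ins for X_test, y_test and y_pred (hypothetical data)
    rng = np.random.default_rng(0)
    X_test = rng.random((20, 256, 256))
    y_test = (rng.random((20, 256, 256)) > 0.5).astype(np.uint8)
    y_pred = rng.random((20, 256 * 256))
    fig, axes = plot_segmentation_grid(X_test, y_test, y_pred)
    fig.savefig("segmentation_grid.png")
```

Passing `squeeze=False` is a design choice so the same indexing works whether you plot one sample or twenty.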