Figures overlapping when plotting using Matplotlib

Time:02-11

I have a plot with multiple images arranged in 20 rows and 3 columns. In each row, the 1st column should be the original image, the 2nd column the mask, and the 3rd column the segmentation obtained from a neural network.

I tried the code below. The first 2 rows plot correctly, but from the 3rd row onward the images overlap the 1st row.

import matplotlib.pyplot as plt
import numpy as np

f = plt.figure()
for j in range(3):
    pred_y = y_pred[j]
    pred_y = pred_y.reshape(256, 256)
    pred_y = (pred_y > 0.1).astype(np.uint8)
    f.add_subplot(j + 1, 3, 1)
    plt.imshow(X_test[j, :, :])

    f.add_subplot(j + 1, 3, 2)
    plt.imshow(y_test[j])

    f.add_subplot(j + 1, 3, 3)
    plt.imshow(pred_y)
plt.show()
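A quick way to see where the loop above actually places each axes is to print their positions. This minimal sketch recreates only the first-column calls, add_subplot(j + 1, 3, 1) for j = 0, 1, 2, on an empty figure (no image data needed) and prints each axes' bounding box in figure coordinates. It assumes a recent Matplotlib and uses the non-interactive Agg backend so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig = plt.figure()
# Recreate only the first-column calls from the loop: add_subplot(j + 1, 3, 1)
axes = [fig.add_subplot(j + 1, 3, 1) for j in range(3)]

for j, ax in enumerate(axes):
    box = ax.get_position()  # Bbox in figure coordinates (0..1)
    print(f"add_subplot({j + 1}, 3, 1): bottom={box.y0:.3f}, top={box.y1:.3f}")
```

All three boxes share the same top edge while their bottom edges move upward, so the first axes spans the full height of the figure and the later ones sit inside it.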

The obtained output is:

[Image: the obtained output]

Answer:

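The arguments to Figure.add_subplot are nrows, ncols, index (index is 1-based and counts along each row first), not row, ncols, col. You are therefore placing every plot in the top row of a 1x3, then a 2x3, then a 3x3 grid, so the axes from later iterations overlap those from the first. Loop over all 20 rows with for j in range(20) and make your calls to add_subplot like this: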

f.add_subplot(20, 3, 3 * j + 1)
f.add_subplot(20, 3, 3 * j + 2)
f.add_subplot(20, 3, 3 * j + 3)
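For reference, here is a minimal end-to-end sketch of the corrected add_subplot loop. X_test, y_test and y_pred are hypothetical random stand-ins with the shapes the question's code implies (256x256 images, predictions flattened to 256 * 256 values), and n_rows is set to 4 to keep the example light; use 20 for the real data. Scaling figsize with the number of rows is an added suggestion, not part of the answer above.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for the question's arrays (random data, fewer rows).
n_rows = 4
rng = np.random.default_rng(0)
X_test = rng.random((n_rows, 256, 256))
y_test = rng.integers(0, 2, size=(n_rows, 256, 256))
y_pred = rng.random((n_rows, 256 * 256))

f = plt.figure(figsize=(9, 3 * n_rows))
for j in range(n_rows):
    pred_y = (y_pred[j].reshape(256, 256) > 0.1).astype(np.uint8)
    # add_subplot(nrows, ncols, index): index is 1-based and runs row by row,
    # so row j, column c (both 0-based) maps to index 3 * j + c + 1.
    f.add_subplot(n_rows, 3, 3 * j + 1).imshow(X_test[j])
    f.add_subplot(n_rows, 3, 3 * j + 2).imshow(y_test[j])
    f.add_subplot(n_rows, 3, 3 * j + 3).imshow(pred_y)
```

Each axes now gets its own cell in a single n_rows x 3 grid, so nothing overlaps.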

A better way to generate your axes might be plt.subplots, as suggested in the demo:

import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(20, 3, constrained_layout=True)
for i in range(20):
    ax[i, 0].imshow(X_test[i])
    ax[i, 1].imshow(y_test[i])
    ax[i, 2].imshow((y_pred[i].reshape(256, 256) > 0.1).astype(np.uint8))
plt.show()
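A slightly fuller sketch of the plt.subplots approach, again with hypothetical random stand-ins for X_test, y_test and y_pred and a reduced n_rows. Passing squeeze=False keeps ax two-dimensional even when n_rows is 1, figsize grows with the number of rows so 20 rows are not squashed into the default figure size, and turning the axes off removes tick labels that would otherwise crowd the grid. The column titles, grayscale colormap and sizes are assumptions about what is wanted, not requirements.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

n_rows = 4  # 20 in the question; kept small here
rng = np.random.default_rng(0)
X_test = rng.random((n_rows, 256, 256))                 # hypothetical images
y_test = rng.integers(0, 2, size=(n_rows, 256, 256))    # hypothetical masks
y_pred = rng.random((n_rows, 256 * 256))                # hypothetical flat predictions

fig, ax = plt.subplots(n_rows, 3, figsize=(9, 3 * n_rows),
                       constrained_layout=True, squeeze=False)
# Label the columns once, on the top row only
for col, title in enumerate(["Image", "Mask", "Prediction"]):
    ax[0, col].set_title(title)
for i in range(n_rows):
    ax[i, 0].imshow(X_test[i], cmap="gray")
    ax[i, 1].imshow(y_test[i], cmap="gray")
    ax[i, 2].imshow((y_pred[i].reshape(256, 256) > 0.1).astype(np.uint8), cmap="gray")
# Hide ticks and frames on every panel
for a in ax.flat:
    a.axis("off")
```

Call plt.show() with an interactive backend, or fig.savefig(...), afterwards to display or save the figure.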