PyTorch: How to plot the prediction output of a segmentation task when the batch size is larger than 1?

Time:08-30

I have a doubt and a question about plotting the output for the different samples of a batch in a segmentation task.

The snippet below plots the probability of each class and the prediction output.

I am sure the probability plots show a single sample of the batch, but I am not sure about the prediction I get from torch.argmax(outputs, 1). Am I plotting the argmax of one sample, given that the output of the network has size [10,4,256,256]?

Also, I am wondering how I can plot the predictions for all samples, given that my batch size is 10.

import torch
import matplotlib.pyplot as plt

# outputs has shape [10, 4, 256, 256] (NCHW); torch.exp below assumes the
# model returns log-probabilities, e.g. from log_softmax
outputs = model(t_image)

fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(nrows=1, ncols=5, sharex=True, sharey=True, figsize=(6,6))

# probability map of each class for the first sample, outputs[0]
img1 = ax1.imshow(torch.exp(outputs[0,0,:,:]).detach().cpu(), cmap='jet')
ax1.set_title("prob class 0")

img2 = ax2.imshow(torch.exp(outputs[0,1,:,:]).detach().cpu(), cmap='jet')
ax2.set_title("prob class 1")

img3 = ax3.imshow(torch.exp(outputs[0,2,:,:]).detach().cpu(), cmap='jet')
ax3.set_title("prob class 2")

img4 = ax4.imshow(torch.exp(outputs[0,3,:,:]).detach().cpu(), cmap='jet')
ax4.set_title("prob class 3")

# argmax over the class dimension of the whole batch
img5 = ax5.imshow(torch.argmax(outputs, 1).detach().cpu().squeeze(), cmap='jet')
ax5.set_title("predicted")

Answer:

I am not entirely sure what you are asking. Assuming you are using the NCHW data layout, your output holds 10 samples per batch, 4 channels (one per class), and a 256x256 resolution. In that case, the first 4 graphs plot the class scores of the four classes for the first sample in the batch, outputs[0] (as probabilities, if the model ends in log_softmax, since you apply torch.exp).
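
To make the layout concrete, here is a minimal sketch that uses a random tensor in place of the real network output. It assumes, as the question's use of torch.exp suggests, that the model ends in log_softmax over the class dimension; the names fake_logits, first_sample and prob_sum are illustrative only.

```python
import torch
import torch.nn.functional as F

# Stand-in for model(t_image): random logits in NCHW layout
# (N=10 samples, C=4 classes, H=W=256). Assumption: the real model
# applies log_softmax over dim=1, so the same is done here.
torch.manual_seed(0)
fake_logits = torch.randn(10, 4, 256, 256)
outputs = F.log_softmax(fake_logits, dim=1)

# outputs[0] selects the first sample of the batch: shape [4, 256, 256]
first_sample = outputs[0]

# outputs[0, c] is the log-probability map of class c for that sample
prob_class0 = torch.exp(outputs[0, 0])  # shape [256, 256]

# Exponentiating and summing over the class dimension gives 1 per pixel
prob_sum = torch.exp(first_sample).sum(dim=0)

print(first_sample.shape)  # torch.Size([4, 256, 256])
print(prob_class0.shape)   # torch.Size([256, 256])
print(torch.allclose(prob_sum, torch.ones(256, 256), atol=1e-5))  # True
```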

For the 5th plot, torch.argmax(outputs, 1).detach().cpu().squeeze() gives you a 10x256x256 array, which holds the class predictions for all 10 images in the batch; squeeze() does not remove the batch dimension because its size is 10, not 1. matplotlib's imshow only accepts 2-D arrays or RGB(A) images, so it cannot plot this directly. Instead, use torch.argmax(outputs[0,:,:,:], 0).detach().cpu().squeeze(), which gives you a 256x256 map that you can plot.
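
The shape difference between the two argmax calls is easy to check. This sketch again substitutes a random tensor for the network output (an assumption for illustration) and confirms that indexing the first sample before the argmax yields the same labels as slicing the batched result.

```python
import torch

# Random stand-in for the network output, NCHW = [10, 4, 256, 256]
torch.manual_seed(0)
outputs = torch.randn(10, 4, 256, 256)

# argmax over the class dimension of the whole batch: one label map per sample
batch_pred = torch.argmax(outputs, 1).squeeze()  # squeeze() is a no-op here
print(batch_pred.shape)  # torch.Size([10, 256, 256]) -> too many dims for imshow

# argmax over the class dimension (dim 0) of the first sample only
single_pred = torch.argmax(outputs[0, :, :, :], 0)
print(single_pred.shape)  # torch.Size([256, 256]) -> a 2-D image imshow accepts

# Both routes give the same labels for sample 0
print(torch.equal(single_pred, batch_pred[0]))  # True
```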

Since the result ranges from 0 to 3, one value per class, it may be displayed as a very dim image (for example, when saved as an 8-bit grayscale image), so people usually color the plot with a palette that maps each class to a color. An example is provided here; look for the cityscapes_map[p] line in the example.
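
Here is a minimal sketch of the palette idea, assuming a hypothetical 4-color palette (the colors and the name palette are arbitrary choices, not taken from the linked example). NumPy fancy indexing, palette[pred], replaces each class index with its RGB triple, which is the same indexing pattern as the cityscapes_map[p] line.

```python
import numpy as np
import torch

# Hypothetical 4-color palette: row k is the RGB color (uint8) for class k
palette = np.array([
    [0, 0, 0],      # class 0: black
    [255, 0, 0],    # class 1: red
    [0, 255, 0],    # class 2: green
    [0, 0, 255],    # class 3: blue
], dtype=np.uint8)

# Random stand-in for the network output, NCHW = [10, 4, 256, 256]
torch.manual_seed(0)
outputs = torch.randn(10, 4, 256, 256)

# 256x256 label map (values 0..3) for the first sample, as a NumPy array
pred = torch.argmax(outputs[0], 0).detach().cpu().numpy()

# Fancy indexing replaces every class index with its RGB triple -> [256, 256, 3]
color_pred = palette[pred]

# The same indexing colors the whole batch at once -> [10, 256, 256, 3]
batch_colors = palette[torch.argmax(outputs, 1).detach().cpu().numpy()]

print(color_pred.shape, color_pred.dtype)  # (256, 256, 3) uint8
print(batch_colors.shape)                  # (10, 256, 256, 3)
```

color_pred can then go straight into ax5.imshow(color_pred); no cmap is needed because the array is already RGB. A matplotlib ListedColormap passed as cmap together with vmin=0, vmax=3 is another way to get fixed per-class colors.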

For plotting all 10, why not write a for loop:

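for i in range(outputs.size(0)):
    # argmax over the class dimension of sample i -> 256x256 label map
    pred = torch.argmax(outputs[i], 0).detach().cpu()
    plt.imshow(pred, cmap='jet')
    plt.title(f"predicted, sample {i}")
    plt.show()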

and go over each result in the batch one by one. Alternatively, you can put 10 rows in your subplot grid, one per sample, if your screen is big enough.
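
Here is a sketch of that 10-row option, assuming (as in the question) that outputs holds NCHW log-probabilities. Each row repeats the question's 5-panel layout for one sample; the function name plot_batch, the figure size, and the random stand-in tensor are illustrative choices, not part of the original code.

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def plot_batch(outputs):
    """Plot class probabilities and the argmax prediction for every sample.

    Assumes outputs is an NCHW tensor of log-probabilities (e.g. from
    log_softmax), as the question's use of torch.exp implies.
    """
    outputs = outputs.detach().cpu()
    n_samples, n_classes = outputs.shape[:2]
    # One row per sample: n_classes probability panels + 1 prediction panel
    fig, axes = plt.subplots(nrows=n_samples, ncols=n_classes + 1,
                             sharex=True, sharey=True,
                             figsize=(2 * (n_classes + 1), 2 * n_samples),
                             squeeze=False)  # keep axes 2-D even for 1 sample
    for i in range(n_samples):
        for c in range(n_classes):
            axes[i, c].imshow(torch.exp(outputs[i, c]).numpy(), cmap='jet')
            axes[i, c].set_title(f"sample {i}, prob class {c}", fontsize=6)
        # argmax over the class dimension of sample i -> [H, W] label map
        pred = torch.argmax(outputs[i], 0).numpy()
        axes[i, n_classes].imshow(pred, cmap='jet', vmin=0, vmax=n_classes - 1)
        axes[i, n_classes].set_title(f"sample {i}, predicted", fontsize=6)
    fig.tight_layout()
    return fig

# Random stand-in for model(t_image): [10, 4, 256, 256] log-probabilities
torch.manual_seed(0)
outputs = F.log_softmax(torch.randn(10, 4, 256, 256), dim=1)
fig = plot_batch(outputs)
```

Call plt.show() afterwards to display the figure. Fixing vmin and vmax on the prediction panels keeps each class mapped to the same color in every row; without them, matplotlib rescales the colormap to each image's own value range.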
