I have a question about plotting the segmentation output for the different images in a batch.
The snippet below plots the probability of each class and the predicted output.
I am sure the probability plots show one image of the batch, but I am not sure about the prediction when I use torch.argmax(outputs, 1). Am I plotting the argmax of one image, given that the network output has size [10,4,256,256]?
Also, how can I plot the predictions for all images in the batch, since my batch size is 10?
import torch
import matplotlib.pyplot as plt

# outputs has shape [N, C, H, W] = [10, 4, 256, 256]; torch.exp assumes the
# model returns log-probabilities (e.g. from log_softmax)
outputs = model(t_image)
fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(nrows=1, ncols=5, sharex=True, sharey=True, figsize=(6, 6))
# probability maps of classes 0-3, taken from index 0 of the batch
img1 = ax1.imshow(torch.exp(outputs[0, 0, :, :]).detach().cpu(), cmap='jet')
ax1.set_title("prob class 0")
img2 = ax2.imshow(torch.exp(outputs[0, 1, :, :]).detach().cpu(), cmap='jet')
ax2.set_title("prob class 1")
img3 = ax3.imshow(torch.exp(outputs[0, 2, :, :]).detach().cpu(), cmap='jet')
ax3.set_title("prob class 2")
img4 = ax4.imshow(torch.exp(outputs[0, 3, :, :]).detach().cpu(), cmap='jet')
ax4.set_title("prob class 3")
# argmax over the class dimension (dim=1)
img5 = ax5.imshow(torch.argmax(outputs, 1).detach().cpu().squeeze(), cmap='jet')
ax5.set_title("predicted")
Answer:
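I'm not entirely sure what you are asking. Assuming you are using the NCHW data layout, your output holds 10 samples per batch, 4 channels (one per class), and a 256x256 resolution. In that case the first four plots show the scores of the four classes for the first sample in the batch (outputs[0, ...]). Note that torch.exp turns those scores into probabilities only if the network outputs log-probabilities, for example from log_softmax.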
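A quick way to confirm the layout is to check tensor shapes. The sketch below uses a random tensor in place of the real model output and assumes, as the torch.exp calls in the question suggest, that the model ends with log_softmax over the class dimension; under that assumption, exponentiating gives per-pixel probabilities that sum to 1 across the 4 channels.

```python
import torch

# Stand-in for model(t_image): random logits in NCHW layout.
# Assumption: batch of 10, 4 classes, 256x256 pixels, as in the question.
torch.manual_seed(0)
logits = torch.randn(10, 4, 256, 256)

# Assumption: the real model ends with log_softmax over dim=1 (the class dimension).
outputs = torch.log_softmax(logits, dim=1)

n, c, h, w = outputs.shape
print(n, c, h, w)  # 10 4 256 256

# outputs[0, k] is the score map of class k for the FIRST image in the batch.
prob_class0_img0 = torch.exp(outputs[0, 0])
print(prob_class0_img0.shape)  # torch.Size([256, 256])

# exp(log_softmax) gives per-pixel probabilities that sum to 1 over the classes.
probs = torch.exp(outputs)
print(torch.allclose(probs.sum(dim=1), torch.ones(n, h, w)))  # True
```

If the last check prints False on the real model output, the network is probably not ending with log_softmax, and the plotted maps are scores rather than probabilities.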
For the fifth plot, torch.argmax(outputs, 1).detach().cpu().squeeze() gives a 10x256x256 tensor: the class predictions for all 10 images in the batch. The squeeze() does not help, because the batch dimension has size 10 rather than 1, and matplotlib's imshow cannot plot a 10x256x256 array directly. To plot only the first image, select it before taking the argmax (once the batch index is fixed, the class dimension becomes dimension 0):
torch.argmax(outputs[0, :, :, :], 0).detach().cpu()
This gives a 256x256 map, which you can plot.
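The shape difference can be checked directly. In this sketch a random tensor stands in for the network output (the values are arbitrary; only the shapes matter), and it also shows that slicing the batch first gives the same map as indexing the full-batch argmax.

```python
import torch

# Hypothetical stand-in for the network output: [N, C, H, W] = [10, 4, 256, 256].
torch.manual_seed(0)
outputs = torch.randn(10, 4, 256, 256)

# argmax over the class dimension (dim=1) keeps the batch dimension.
pred_all = torch.argmax(outputs, 1).detach().cpu().squeeze()
print(pred_all.shape)  # torch.Size([10, 256, 256]); squeeze() finds no size-1 dim

# Select the first image first; its class dimension is now dim 0.
pred_first = torch.argmax(outputs[0], 0).detach().cpu()
print(pred_first.shape)  # torch.Size([256, 256])

# Both routes give the same class map for image 0.
print(torch.equal(pred_all[0], pred_first))  # True
```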
Since the result ranges from 0 to 3, one value per class, the raw map has very little contrast in some viewers (saved directly as an 8-bit grayscale image, for instance, it looks almost black), so people usually color it with a palette. An example is provided here; see the cityscapes_map[p] line in it.
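The palette idea comes down to indexing a color lookup table with the class map. The sketch below uses a made-up 4-color palette (arbitrary placeholder colors, not the Cityscapes ones) and NumPy fancy indexing, so palette[pred] turns an HxW map of class indices into an HxWx3 RGB image. An alternative that stays inside matplotlib is passing vmin=0, vmax=3 to imshow, so that each class keeps the same color even in images where some classes are absent.

```python
import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

# Hypothetical palette: one RGB color per class (values chosen arbitrarily).
palette = np.array([
    [0, 0, 0],      # class 0: black
    [255, 0, 0],    # class 1: red
    [0, 255, 0],    # class 2: green
    [0, 0, 255],    # class 3: blue
], dtype=np.uint8)

# Stand-in for torch.argmax(outputs[0], 0): a 256x256 map of class indices 0..3.
torch.manual_seed(0)
outputs = torch.randn(10, 4, 256, 256)
pred = torch.argmax(outputs[0], 0).detach().cpu().numpy()

# Fancy indexing: each pixel's class index selects one row of the palette.
rgb = palette[pred]
print(rgb.shape, rgb.dtype)  # (256, 256, 3) uint8

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
ax1.imshow(rgb)  # fixed colors taken from the palette
ax1.set_title("predicted (palette)")
# Alternative: a colormap with a fixed value range, so class k always gets the same color.
ax2.imshow(pred, cmap="jet", vmin=0, vmax=3, interpolation="nearest")
ax2.set_title("predicted (jet, vmin=0, vmax=3)")
plt.close(fig)
```

interpolation="nearest" avoids blending neighboring class colors when the image is resized on screen.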
To plot all 10, you can write a for loop:
for i in range(outputs.size(0)):
    # class prediction for the i-th image in the batch
    pred = torch.argmax(outputs[i], 0).detach().cpu()
    plt.imshow(pred, cmap='jet')
    plt.title(f"predicted, image {i}")
    plt.show()
and go over each result in the batch one by one. Alternatively, if your screen is big enough, you can create a subplot grid with 10 rows, one per image.
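The 10-row subplot option might look like the sketch below. It assumes log_softmax outputs in NCHW layout (a random tensor stands in for model(t_image)) and puts each image's four probability maps plus its predicted class map in one row. Because a 10-row grid is hard to read on screen, the sketch saves it to a file; the file name is a hypothetical placeholder.

```python
import torch
import matplotlib
matplotlib.use("Agg")  # headless backend for saving to a file; remove for an interactive window
import matplotlib.pyplot as plt

# Hypothetical stand-in for model(t_image): log-probabilities in NCHW layout,
# assuming the model ends with log_softmax over the class dimension.
torch.manual_seed(0)
outputs = torch.log_softmax(torch.randn(10, 4, 256, 256), dim=1)

probs = torch.exp(outputs).detach().cpu().numpy()        # [N, C, H, W]
preds = torch.argmax(outputs, 1).detach().cpu().numpy()  # [N, H, W]
n, num_classes = probs.shape[:2]

# One row per image: the class probability maps, then the predicted class map.
fig, axes = plt.subplots(nrows=n, ncols=num_classes + 1, sharex=True,
                         sharey=True, figsize=(10, 2 * n))
for i in range(n):
    for k in range(num_classes):
        axes[i, k].imshow(probs[i, k], cmap="jet", vmin=0, vmax=1)
        axes[i, k].set_title(f"img {i}: prob class {k}", fontsize=6)
    # Fixed vmin/vmax keep the class-to-color mapping identical in every row.
    axes[i, -1].imshow(preds[i], cmap="jet", vmin=0, vmax=num_classes - 1,
                       interpolation="nearest")
    axes[i, -1].set_title(f"img {i}: predicted", fontsize=6)

fig.tight_layout()
fig.savefig("batch_predictions.png", dpi=100)  # hypothetical output path
plt.close(fig)
```

Fixing vmin and vmax matters here: without them, imshow rescales each panel to its own min and max, so the same color can mean different classes in different rows.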