Using plt to display pytorch image

Time:11-12

x is the image, y is the label, and metadata holds dates, times, etc.

for x, y_true, metadata in train_loader:
  print(x.shape)

This prints:

torch.Size([16, 3, 448, 448])

How do I go about displaying x as an image? Do I use plt?

Answer:

Your x is not a single image but a batch of 16 images. The shape follows PyTorch's (N, C, H, W) layout, so each image has 3 channels (presumably RGB) and is 448x448 pixels.
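If you only want to look at one image from the batch, index it out and move the channel dimension last: plt expects (height, width, channels), while PyTorch stores (channels, height, width). A minimal sketch, assuming x holds float pixel values already in [0, 1]; the helper name batch_image_to_hwc is hypothetical:

```python
import torch


def batch_image_to_hwc(batch: torch.Tensor, index: int = 0):
    """Return image `index` of an (N, C, H, W) batch as an (H, W, C) numpy array.

    Hypothetical helper for illustration; assumes float pixels in [0, 1].
    """
    img = batch[index]                   # (C, H, W)
    img = img.detach().cpu()             # drop autograd history, copy to host if on GPU
    return img.permute(1, 2, 0).numpy()  # (H, W, C): channels last, as plt expects


if __name__ == "__main__":
    x = torch.rand(16, 3, 448, 448)        # random stand-in for a batch from train_loader
    print(batch_image_to_hwc(x, 0).shape)  # (448, 448, 3)
```

You would then display it with plt.imshow(batch_image_to_hwc(x, 0)) followed by plt.show().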

You can use torchvision.utils.make_grid to convert x into a grid of 4x4 images, and then plot it:

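import torch
import torchvision
import matplotlib.pyplot as plt

with torch.no_grad():  # no need for gradients here
  grid = torchvision.utils.make_grid(x, nrow=4)  # you might consider normalize=True
  # convert the (C, H, W) grid into an (H, W, C) numpy array suitable for plt
  grid_np = grid.cpu().numpy().transpose(1, 2, 0)  # channel dim should be last
plt.imshow(grid_np)  # imshow handles RGB arrays; matshow is meant for 2D matrices
plt.axis("off")  # hide the pixel-coordinate ticks
plt.show()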
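The normalize=True hint matters when the loader applies something like transforms.Normalize: pixel values can then be negative or above 1, and plt clips them (with a warning) when showing an RGB float array. make_grid(..., normalize=True) shifts the images into [0, 1] using their min and max. Below is a rough hand-written equivalent for the default settings (one min/max over the whole batch), offered as a sketch rather than torchvision's actual code; rescale_to_unit_range is a hypothetical name:

```python
import torch


def rescale_to_unit_range(t: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Min-max rescale `t` into [0, 1] using a single min and max over all elements.

    Hypothetical helper approximating make_grid(..., normalize=True) with the
    default value_range=None and scale_each=False.
    """
    t = t.float()                                   # avoid integer arithmetic on uint8 input
    low, high = t.min(), t.max()
    return (t - low) / max(float(high - low), eps)  # eps guards against a constant image


if __name__ == "__main__":
    x = torch.randn(16, 3, 448, 448)  # stand-in for a mean/std-normalized batch
    y = rescale_to_unit_range(x)
    print(float(y.min()), float(y.max()))
```

Min-max rescaling only brings the values into a displayable range; it does not exactly undo a mean/std normalization, so colours may look slightly different from the raw images.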