x is the image, y is the label, and metadata holds dates, times, etc.
for x, y_true, metadata in train_loader:
    print(x.shape)
This prints:
torch.Size([16, 3, 448, 448])
How do I go about displaying x as an image? Do I use plt?
CodePudding user response:
Your x is not a single image, but rather a batch of 16 different images, each with 3 channels and 448x448 pixels. You can use torchvision.utils.make_grid to convert x into a 4x4 grid of images, and then plot it:
import matplotlib.pyplot as plt
import torch
import torchvision

with torch.no_grad():  # no need for gradients here
    grid = torchvision.utils.make_grid(x, nrow=4)  # you might consider normalize=True
# convert the grid into a numpy array suitable for plt
grid_np = grid.cpu().numpy().transpose(1, 2, 0)  # channel dim should be last
plt.imshow(grid_np)  # imshow accepts (H, W, 3) RGB arrays; floats should lie in [0, 1]
plt.show()
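To make the layout step less opaque, here is a rough NumPy-only sketch of the same idea: scale a batch to [0, 1], tile it into rows, and move the channel dimension last. It is not torchvision's implementation (make_grid also adds padding between images, for instance). The helper name batch_to_grid and the random stand-in batch x_np are made up for illustration.

```python
import numpy as np

def batch_to_grid(batch, nrow=4):
    """Tile an (N, C, H, W) array into one (rows*H, nrow*W, C) image in [0, 1]."""
    batch = np.asarray(batch, dtype=np.float32)
    n, c, h, w = batch.shape
    # Min-max scale to [0, 1], the range plt.imshow expects for float RGB data.
    lo, hi = batch.min(), batch.max()
    batch = (batch - lo) / max(hi - lo, 1e-8)
    rows = -(-n // nrow)  # ceiling division: number of grid rows
    # Pad with black images so the batch fills the last row.
    pad = rows * nrow - n
    if pad:
        batch = np.concatenate([batch, np.zeros((pad, c, h, w), np.float32)])
    # (rows, nrow, C, H, W) -> (rows, H, nrow, W, C) -> (rows*H, nrow*W, C)
    grid = batch.reshape(rows, nrow, c, h, w).transpose(0, 3, 1, 4, 2)
    return grid.reshape(rows * h, nrow * w, c)

# Hypothetical stand-in for one batch from train_loader (random data).
x_np = np.random.default_rng(0).normal(size=(16, 3, 448, 448))
grid_np = batch_to_grid(x_np, nrow=4)
print(grid_np.shape)  # (1792, 1792, 3)
```

With a real batch, something like batch_to_grid(x.cpu().numpy()) followed by plt.imshow(...) should give a comparable picture, assuming x is a float tensor of shape (N, 3, H, W).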