Add second and third channel to a tensor - Pytorch

Time:10-04

I am trying to apply a mask to an image. My image has shape [360, 480, 3] and my mask has shape [360, 480, 1]. How do I create a mask of the same shape as my image in PyTorch? Also, in this case, would the green and blue channels contain zeros or the same values as the red channel? Thanks

Answer:

Logically, the single-channel mask should be replicated so that the green and blue channels get the same values as the red channel, not zeros. PyTorch provides an easy way of doing this:

# repeat the last dimension thrice
mask_three_channel = mask.repeat(1, 1, 3)

Then mask your image X (of shape [360, 480, 3]) with:

X_masked = X * mask_three_channel
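A minimal self-contained sketch of the repeat approach follows. The random tensors X and mask are illustrative stand-ins for the asker's data (assumed: a float image and a 0/1 mask). It also checks that every channel of the repeated mask holds the same values as the original single channel, which answers the second part of the question.

```python
import torch

# Illustrative stand-ins for the asker's tensors (random data, assumed shapes)
X = torch.rand(360, 480, 3)                  # image, channels last
mask = torch.randint(0, 2, (360, 480, 1))    # binary mask with one channel

# Repeat the last dimension three times: [360, 480, 1] -> [360, 480, 3]
mask_three_channel = mask.repeat(1, 1, 3)

# Each of the three channels is a copy of the single input channel (no zeros added)
same_in_all_channels = all(
    torch.equal(mask_three_channel[..., c], mask[..., 0]) for c in range(3)
)

X_masked = X * mask_three_channel

print(mask_three_channel.shape)   # torch.Size([360, 480, 3])
print(same_in_all_channels)       # True
```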

Answer:

When broadcasting a mask, you should use Tensor.expand rather than Tensor.repeat to avoid copying data unnecessarily: expand returns a view of the original tensor, whereas repeat allocates a new copy.
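The view-versus-copy difference can be observed directly. In this sketch (random mask assumed as stand-in data), the expanded tensor starts at the same memory address as the original and has stride 0 along the stretched dimension, while repeat allocates fresh storage.

```python
import torch

mask = torch.randint(0, 2, (360, 480, 1))

# expand: -1 keeps a dimension's size; the size-1 last dimension is stretched to 3
mask_view = mask.expand(-1, -1, 3)
# repeat: materialises three copies of the data in new memory
mask_copy = mask.repeat(1, 1, 3)

# The expanded view shares memory with the original mask...
shares_memory = mask_view.data_ptr() == mask.data_ptr()
print(shares_memory)                            # True
# ...and uses stride 0 along the expanded dimension
print(mask_view.stride())                       # (480, 1, 0)
# The repeated tensor lives in separate memory
print(mask_copy.data_ptr() == mask.data_ptr())  # False
# Both contain the same values
print(torch.equal(mask_view, mask_copy))        # True
```

Because several elements of an expanded tensor refer to the same memory location, the PyTorch documentation advises cloning it before any in-place writes.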

However, in this case there is no need to expand the mask at all, because the * operator broadcasts it automatically:

>>> import torch
>>> rgb = torch.rand(360, 480, 3)
>>> mask = torch.randint(0, 2, (360, 480, 1))

>>> rgb_masked = rgb * mask
>>> rgb_masked.shape
torch.Size([360, 480, 3])

The resulting tensor contains [0, 0, 0] at every pixel position where the mask is 0, and keeps the original RGB values where the mask is 1.
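To check this, the sketch below (random stand-in data, as in the session above) compares implicit broadcasting with the explicit expand and repeat variants from the two answers, then verifies that masked-out pixels become [0, 0, 0] while the others are unchanged.

```python
import torch

rgb = torch.rand(360, 480, 3)
mask = torch.randint(0, 2, (360, 480, 1))

# Three ways of applying the same mask
by_broadcast = rgb * mask                    # implicit broadcasting by the operator
by_expand = rgb * mask.expand(-1, -1, 3)     # explicit view, no copy
by_repeat = rgb * mask.repeat(1, 1, 3)       # explicit copy

print(torch.equal(by_broadcast, by_expand))  # True
print(torch.equal(by_broadcast, by_repeat))  # True

# Boolean [360, 480] selector for pixels where the mask is 0
off = mask[..., 0] == 0
print(bool((by_broadcast[off] == 0).all()))          # True: those pixels are [0, 0, 0]
print(torch.equal(by_broadcast[~off], rgb[~off]))    # True: other pixels are unchanged
```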
