I am trying to apply a mask to an image. My image has shape [360, 480, 3] and my mask has shape [360, 480, 1]. How do I create a mask of the same shape as my image in PyTorch? Also, in this case, would the green and blue channels contain zeros or the same values as the red channel? Thanks
Answer:
Logically, the mask should be replicated for the green and blue channels, so all three channels carry the same values as the red one (not zeros). PyTorch provides an easy way of doing this:
# repeat the last dimension thrice
mask_three_channel = mask.repeat(1, 1, 3)
Then mask your image X (of shape [360, 480, 3]) as:
X_masked = X * mask_three_channel
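A small self-contained version of this answer, using hypothetical 4 x 5 sizes so it runs instantly, checks that the repeated mask carries identical values in every channel before applying it:

```python
import torch

# Hypothetical small sizes for illustration (the question uses 360 x 480)
H, W = 4, 5
X = torch.rand(H, W, 3)                        # RGB image, channels last
mask = torch.randint(0, 2, (H, W, 1)).float()  # binary single-channel mask

# Repeat the last dimension three times: [H, W, 1] -> [H, W, 3]
mask_three_channel = mask.repeat(1, 1, 3)

# Green and blue get the same mask values as red, not zeros
assert torch.equal(mask_three_channel[..., 0], mask_three_channel[..., 1])
assert torch.equal(mask_three_channel[..., 0], mask_three_channel[..., 2])

X_masked = X * mask_three_channel
print(X_masked.shape)  # torch.Size([4, 5, 3])
```

Note that repeat allocates a new tensor holding three full copies of the mask data.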
Answer:
When broadcasting a mask, you should use torch.expand rather than torch.repeat to avoid copying unnecessary data: torch.expand returns a view, not a copy.
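The difference between the two can be checked directly: the expanded tensor shares memory with the original mask and uses a stride of 0 along the new channel dimension, while the repeated tensor is a freshly allocated contiguous copy. A minimal sketch:

```python
import torch

mask = torch.randint(0, 2, (360, 480, 1))

expanded = mask.expand(-1, -1, 3)  # view: -1 keeps a dim, size-1 dim becomes 3
repeated = mask.repeat(1, 1, 3)    # copy: data materialized three times

print(expanded.shape, repeated.shape)  # both torch.Size([360, 480, 3])

# The view points at the same memory as the original mask...
print(expanded.data_ptr() == mask.data_ptr())  # True
print(expanded.stride())  # (480, 1, 0): stride 0 along the channel dim
print(repeated.stride())  # (1440, 3, 1): an ordinary contiguous layout

# ...yet both hold the same values
print(torch.equal(expanded, repeated))  # True
```

Because several elements of an expanded view refer to the same memory location, avoid writing into it in place; call .clone() first if you need a writable tensor.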
However, in this case there is no need to expand the mask at all, because the * operator broadcasts the size-1 last dimension automatically:
>>> rgb = torch.rand(360, 480, 3)
>>> mask = torch.randint(0, 2, (360, 480, 1))
>>> rgb_masked = rgb*mask
>>> rgb_masked.shape
torch.Size([360, 480, 3])
The resulting tensor contains [0, 0, 0] at pixel positions where mask equals 0, and the original RGB values where mask equals 1.
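The sketch below checks that implicit broadcasting gives the same result as expanding the mask explicitly. It also covers one assumed variant not in the question: a mask stored as [360, 480] with no trailing channel dimension (here called mask_2d, a hypothetical name). Such a mask must be unsqueezed first, otherwise broadcasting aligns 480 against the channel dimension 3 and raises an error.

```python
import torch

rgb = torch.rand(360, 480, 3)
mask = torch.randint(0, 2, (360, 480, 1))

# Broadcasting stretches the size-1 channel dim of mask to 3 implicitly
rgb_masked = rgb * mask
assert torch.equal(rgb_masked, rgb * mask.expand_as(rgb))

# Pixels where mask == 0 become [0, 0, 0]; the others keep their RGB values
off = mask[..., 0] == 0
assert (rgb_masked[off] == 0).all()
assert torch.equal(rgb_masked[~off], rgb[~off])

# Assumed variant: a [360, 480] mask needs a trailing dim before multiplying
mask_2d = mask[..., 0]
rgb_masked_2d = rgb * mask_2d.unsqueeze(-1)
assert torch.equal(rgb_masked_2d, rgb_masked)
```

Since rgb is float32 and mask is int64, PyTorch's type promotion makes the product float32.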