I am working on a computer vision problem, and in the image preprocessing step I encountered a problem that I cannot solve.
Let's say I have a pair of images, (image, mask). image is a 3-channel image with shape (H,W,3), while mask is a 1-channel image with shape (H,W,1).
What I'm trying to do is set all pixels in image to 0 at positions where mask is 0.
My first solution was using a double for loop and it worked.
for y in range(mask.shape[1]):
    for x in range(mask.shape[2]):
        if mask[:,y,x] == 0:
            img[y,x,:] = 0
- Don't be confused by the different indexing: in this solution img is a torch.Tensor, and I cast everything to torch.Tensor when I return from the function that this block is part of.
However, it's too slow when I'm training my models; my batch loading hangs.
My next solution was np.logical_not, but it raises an error because img and mask have a different number of channels.
img[np.logical_not(mask)] = 0
Results in
IndexError: boolean index did not match indexed array along dimension 2; dimension is 3 but corresponding boolean dimension is 1
I also tried
img[mask==0] = 0
Which results in the same error message as above.
How can I solve this without being too slow?
Thanks in advance!
CodePudding user response:
It works when the mask shape is (H,W). So
img[mask.squeeze()==0] = 0
or
img[mask[...,0]==0] = 0
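The reason this works: a boolean index has to match the leading dimensions of the array it indexes, so a (H,W) mask selects whole pixels, while a (H,W,1) mask clashes with the 3 channels. Here is a minimal sketch that checks the indexed assignment against the double loop. The shapes and random data are made up for illustration, and it assumes channel-last NumPy arrays as in the question.

```python
import numpy as np

# Hypothetical data: a small random image and a binary mask with a trailing channel axis.
rng = np.random.default_rng(0)
H, W = 4, 5
img = rng.random((H, W, 3)) + 0.1          # strictly positive, so zeroed pixels stand out
mask = rng.integers(0, 2, size=(H, W, 1))  # 0 = zero this pixel, 1 = keep it

masked = img.copy()                        # the assignment is in place, so work on a copy
masked[mask[..., 0] == 0] = 0              # (H, W) boolean index selects whole pixels

# Reference result from the slow double loop.
expected = img.copy()
for y in range(H):
    for x in range(W):
        if mask[y, x, 0] == 0:
            expected[y, x, :] = 0

same = np.array_equal(masked, expected)
```

Note that mask.squeeze() removes every axis of length 1, so mask[...,0] is the safer choice if H or W could ever be 1.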
CodePudding user response:
You can squeeze the last dimension of your mask, as proposed in the other answer, or you can also use np.where as follows:
import numpy as np
H, W, C = 8, 8, 3
img = np.random.rand(H, W, C)
mask = np.all(img < 0.5, axis=-1, keepdims=True)
print(mask.shape) # (8, 8, 1)
default = 0
img_masked = np.where(mask, default, img)
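Note that the snippet above zeroes pixels where mask is True. With the convention from the question (zero the image where mask is 0), the condition would be mask == 0 instead. A small sketch of that variant, assuming a binary integer mask of shape (H,W,1); the random data is only for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
H, W, C = 8, 8, 3
img = rng.random((H, W, C)) + 0.1          # strictly positive, so zeroed pixels stand out
mask = rng.integers(0, 2, size=(H, W, 1))  # 0 = zero this pixel, 1 = keep it

# The (H, W, 1) condition broadcasts against the (H, W, 3) image,
# so no squeeze is needed. np.where returns a new array; img is unchanged.
img_masked = np.where(mask == 0, 0, img)

# Same result by multiplying with the broadcast boolean mask.
img_masked_mul = img * (mask != 0)
```

Both forms avoid the Python loop; unlike img[mask[...,0]==0] = 0, neither modifies img in place.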
Plotting the result:
import matplotlib.pyplot as plt
fig, (ax_orig, ax_mask, ax_masked) = plt.subplots(ncols=3, sharey=True)
ax_orig.set_title("Original")
ax_orig.imshow(img)
ax_mask.set_title("Mask")
ax_mask.imshow(mask, cmap="gray")
ax_masked.set_title("Masked")
ax_masked.imshow(img_masked)
plt.show()