NumPy - Set all pixels of the target image to 0 where the mask is 0

Time:10-10

I am working on a computer vision problem, and in the image preprocessing step I ran into a problem that I cannot solve.

Let's say I have a pair of images - (image, mask). image is a 3-channel image with shape (H,W,3), while mask is a 1-channel image with shape (H,W,1).

What I'm trying to do is set all pixels of image to 0 at positions where mask is 0. My first solution used a double for loop, and it worked.

for y in range(mask.shape[1]):
    for x in range(mask.shape[2]):
        if mask[:, y, x] == 0:
            img[y, x, :] = 0
  • Don't be confused by the different indexing: in this solution img is a torch.Tensor, and I cast everything to torch.Tensor when returning from the function this block is part of.

However, it is too slow when training my models; my batch loading hangs.

My next attempt was np.logical_not, but it raises an error because img and mask have a different number of channels.

img[np.logical_not(mask)] = 0

Results in

IndexError: boolean index did not match indexed array along dimension 2; dimension is 3 but corresponding boolean dimension is 1

I also tried

img[mask==0] = 0

This results in the same error message as above.

How can I solve this without being too slow?

Thanks in advance!

CodePudding user response:

Boolean indexing works when the mask has shape (H,W), so drop the trailing channel axis:

img[mask.squeeze()==0] = 0

or

img[mask[...,0]==0] = 0
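
To see why this works: a 2-D boolean index of shape (H, W) selects whole pixels, so the assignment zeroes all three channels at once. Below is a small sketch that checks the indexed assignment against the double loop from the question. The sizes, the seed and the random 0/1 mask are assumptions made up for illustration, and the loop is rewritten for NumPy's (H, W, C) layout rather than the torch layout used in the question.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, only for reproducibility
H, W = 6, 7                     # arbitrary illustrative sizes
img = rng.random((H, W, 3))
mask = rng.integers(0, 2, size=(H, W, 1))  # assumed 0/1 mask with a trailing channel axis

# Reference: the slow double loop, written for the (H, W, C) layout.
expected = img.copy()
for y in range(H):
    for x in range(W):
        if mask[y, x, 0] == 0:
            expected[y, x, :] = 0

# Vectorized: drop the channel axis so the boolean index has shape (H, W).
result = img.copy()
result[mask[..., 0] == 0] = 0

print(np.array_equal(result, expected))
```

One detail on the two variants: mask.squeeze() with no argument removes every axis of length 1, so it would also drop H or W if either happened to be 1. mask[..., 0] or mask.squeeze(-1) only removes the channel axis.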

CodePudding user response:

You can squeeze the last dimension of your mask, as proposed in the other answer, or you can also use np.where as follows:

import numpy as np

H, W, C = 8, 8, 3
img = np.random.rand(H, W, C)
mask = np.all(img < 0.5, axis=-1, keepdims=True)
print(mask.shape)  # (8, 8, 1)

default = 0
img_masked = np.where(mask, default, img)
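
Note that the demo above zeroes the pixels where mask is True. For the convention in the question (zero where the mask is 0), the condition would be flipped. A minimal sketch, assuming an integer 0/1 mask of shape (H, W, 1) filled with random values; the multiplication variant is an alternative that does not appear in the answers and relies on the same broadcasting:

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, only for reproducibility
H, W, C = 8, 8, 3
img = rng.random((H, W, C))
mask = rng.integers(0, 2, size=(H, W, 1))  # assumed 0/1 mask

# (H, W, 1) broadcasts against (H, W, 3), so no squeeze is needed here.
masked_where = np.where(mask == 0, 0, img)

# Alternative: multiply by the boolean mask (False acts as 0, True as 1).
masked_mul = img * (mask != 0)

print(np.array_equal(masked_where, masked_mul))
```

Both variants return a new array and leave img untouched, which may or may not be what you want inside a data-loading pipeline.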

Plotting the result:

import matplotlib.pyplot as plt

fig, (ax_orig, ax_mask, ax_masked) = plt.subplots(ncols=3, sharey=True)

ax_orig.set_title("Original")
ax_orig.imshow(img)

ax_mask.set_title("Mask")
ax_mask.imshow(mask[..., 0], cmap="gray")  # drop the channel axis; imshow documents (H, W) for grayscale data

ax_masked.set_title("Masked")
ax_masked.imshow(img_masked)

plt.show()
