Selecting from a multidimensional NumPy array with a multidimensional mask

Time:05-14

I am trying to build an example to understand image segmentation. You are given an image of shape (1,2,2,3): a 2x2 image where each pixel has 3 numbers giving the probability that the pixel belongs to each class. What I want is an output of shape (1,2,2,1) that holds the class of each pixel based on the probability values (select the highest). The code below shows the problem.

import numpy as np

np.random.seed(2)
pixel = np.random.random((1, 2, 2, 3))  # the image
pixel[0, 0, 1, :]  # pixel at row 0, column 1: three class probabilities; it should belong to class 0

Output:

array([0.43532239, 0.4203678 , 0.33033482])  # index 0 is the highest

The mask I created:

mask = pixel.argmax(-1)       # index of the most probable class per pixel, shape (1,2,2)
mask = mask[..., np.newaxis]  # shape (1,2,2,1)

Now I have the image and the mask, but I don't know how to select with it.

Please show me how to solve this, and where I can learn this kind of slicing and selection in NumPy.

You can think of it like this:

Input: a (1,2,2,3) image with three class probabilities for each pixel, and a (1,2,2,1) mask giving the class to select

Output: a (1,2,2,1) image with only one class for each pixel

CodePudding user response:

I think you want to use:

np.take_along_axis(pixel, mask, axis = -1)
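For intuition about this kind of selection, here is a minimal sketch of what np.take_along_axis amounts to for these shapes, written with plain integer-array indexing. It assumes pixel and mask are built exactly as in the question; the names b, r, c and manual are illustrative only.

```python
import numpy as np

np.random.seed(2)
pixel = np.random.random((1, 2, 2, 3))    # the image from the question
mask = pixel.argmax(-1)[..., np.newaxis]  # shape (1, 2, 2, 1)

# Sparse index grids for the first three axes; each has size 1 everywhere
# except along its own axis, so they broadcast against mask.
b, r, c, _ = np.indices(mask.shape, sparse=True)

# For every (batch, row, col) position, pick the entry whose last-axis
# index is stored in mask.
manual = pixel[b, r, c, mask]

print(manual.shape)
print(np.array_equal(manual, np.take_along_axis(pixel, mask, axis=-1)))
```

The index arrays broadcast to the shape of mask, which is why the result keeps the trailing axis of length 1.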

You would also be able to get the same result without the use of a mask by using:

pixel.max(axis = -1, keepdims = True)
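As a quick check, the sketch below runs both approaches on the data from the question and compares them. It assumes the same seed and shapes as above; selected and direct are illustrative names.

```python
import numpy as np

np.random.seed(2)
pixel = np.random.random((1, 2, 2, 3))    # the image from the question
mask = pixel.argmax(-1)[..., np.newaxis]  # class index per pixel

# Probability of the winning class, selected through the mask.
selected = np.take_along_axis(pixel, mask, axis=-1)

# Same values computed directly, without a mask.
direct = pixel.max(axis=-1, keepdims=True)

print(selected.shape)                    # (1, 2, 2, 1)
print(np.array_equal(selected, direct))  # True
print(mask[0, 0, 1, 0])                  # 0, the class of the pixel at row 0, column 1
```

Both calls return the winning probability for each pixel. If what you need is the class label itself, mask already holds it.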