Numpy: np.where() with new dimensions

Time:11-01

I'm a Python/NumPy beginner, so this should be easy to solve. Given a 2D NumPy array of floats, map, e.g.:

map = [[0.19982308 0.19982308 0.19986019 ... 0.25456086 0.25463998 0.25463998]
      [0.19982308 0.19982308 0.19986019 ... 0.25456086 0.25463998 0.25463998]
      [0.19998285 0.19998285 0.20000038 ... 0.25459546 0.25466287 0.25466287]
      ...
      [0.4762167  0.4762167  0.47602317 ... 0.45300224 0.4541465  0.4541465 ]
      [0.4767613  0.4767613  0.47632453 ... 0.45406988 0.45538843 0.45538843]
      [0.4767613  0.4767613  0.47632453 ... 0.45406988 0.45538843 0.45538843]]

I want to carry out this operation:

new_map = np.where(map > 0.4, [255,255,255], [0,0,0])

That is, I want to create a new array with the same rows and columns as map, but holding an RGB value instead of a float at each position. Which RGB value is assigned to new_map[x][y], white = [255,255,255] or black = [0,0,0], is determined by whether map[x][y] is above a threshold (0.4 in the case above).

I get the following error message: operands could not be broadcast together with shapes (512,512) (3,) (3,)

I think I understand why: np.where is restricted to the dimensions of map, and I'm in effect trying to increase those dimensions by replacing each float with a nested array of length three.
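One way to check this is to ask NumPy directly which shapes broadcast together. np.where broadcasts the condition and both choice arrays against each other, comparing shapes starting from the trailing dimension. A minimal sketch, assuming NumPy 1.20 or later (for np.broadcast_shapes) and reusing the shapes from the error message:

```python
import numpy as np

# Shapes taken from the error message: condition (512, 512), two choices (3,).
map_shape = (512, 512)
color_shape = (3,)

# Broadcasting aligns shapes from the right: 512 vs 3 differ and neither
# is 1, so np.broadcast_shapes raises ValueError.
try:
    np.broadcast_shapes(map_shape, color_shape, color_shape)
    compatible = True
except ValueError:
    compatible = False

# A trailing axis of length 1 makes them compatible:
# (512, 512, 1) with (3,) broadcasts to (512, 512, 3).
result_shape = np.broadcast_shapes((512, 512, 1), color_shape, color_shape)

print(compatible)    # False
print(result_shape)  # (512, 512, 3)
```

So the problem is not that np.where cannot add dimensions, but that the trailing dimensions of the condition and the color arrays do not line up.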

Is there a workaround for this issue using where or any other numpy operation? Thanks!

Answer:

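Make sure map is a NumPy array first, then allocate a (rows, cols, 3) output filled with black and use a boolean mask to set the pixels above the threshold to white: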

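import numpy as np
#import matplotlib.pyplot as plt

#create map (random stand-in data; use map = np.asarray(map) if yours is a list)
map = np.random.rand(200, 200)

#to show
#plt.matshow(map)

# (rows, cols, 3) output, black everywhere; uint8 is the usual dtype for 0-255 RGB
output_var = np.zeros((*map.shape, 3), dtype=np.uint8)

# the boolean mask selects pixels above the threshold; the RGB triple broadcasts over them
output_var[map > 0.4] = np.array([255, 255, 255])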
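The same mask-assignment idea can be wrapped in a small reusable function that also supports a background color other than black. This is a sketch: threshold_to_rgb and its parameters are hypothetical names chosen for illustration, and the random input stands in for the real map:

```python
import numpy as np


def threshold_to_rgb(values, threshold=0.4, high=(255, 255, 255), low=(0, 0, 0)):
    """Map a 2D float array to a (rows, cols, 3) uint8 RGB image.

    Pixels strictly above `threshold` get `high`; all others get `low`.
    (threshold_to_rgb is a hypothetical helper, not a NumPy function.)
    """
    values = np.asarray(values)
    mask = values > threshold                  # boolean, shape (rows, cols)
    out = np.empty((*values.shape, 3), dtype=np.uint8)
    out[mask] = high                           # (3,) broadcasts over the selected pixels
    out[~mask] = low                           # remaining pixels get the low color
    return out


rng = np.random.default_rng(0)
demo = rng.random((200, 200))
img = threshold_to_rgb(demo)
print(img.shape, img.dtype)  # (200, 200, 3) uint8
```

Using np.empty is safe here because mask and ~mask together cover every pixel, so every entry is written exactly once.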

Answer:

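This is a broadcasting issue: numpy.where requires the condition and both choice arrays to have shapes that broadcast together. Broadcasting compares shapes from the trailing dimension, so (512, 512) and (3,) clash (512 vs 3). Assuming the expected output shape is (y, x, 3), you can add a trailing axis to map: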

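import numpy as np

map_reshape = np.expand_dims(map, -1)                # (y, x) -> (y, x, 1)
white = np.array([255, 255, 255]).reshape(1, 1, -1)  # shape (1, 1, 3)
black = np.array([0, 0, 0]).reshape(1, 1, -1)        # shape (1, 1, 3)
new_map = np.where(map_reshape > 0.4, white, black)  # broadcasts to (y, x, 3)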

If the shape of map is (y, x), the shape of new_map will be (y, x, 3).
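A different technique, not used in either answer, also fits the question's "any other numpy operation": lookup-table (palette) indexing. The boolean mask is turned into 0/1 integer indices that select rows of a small color table, which extends naturally to more than two colors. A minimal sketch with random stand-in data; palette and labels are hypothetical names:

```python
import numpy as np

# Two-entry lookup table: row 0 = black, row 1 = white.
palette = np.array([[0, 0, 0], [255, 255, 255]], dtype=np.uint8)

rng = np.random.default_rng(1)
values = rng.random((4, 5))  # stand-in for the question's `map`

# Boolean mask -> 0/1 indices; integer-array indexing then returns
# shape labels.shape + palette.shape[1:] = (4, 5, 3).
labels = (values > 0.4).astype(np.intp)
new_map = palette[labels]

# Compare with the np.where version that adds a trailing axis.
expected = np.where(values[..., np.newaxis] > 0.4,
                    np.array([255, 255, 255]), np.array([0, 0, 0]))
print(new_map.shape)                      # (4, 5, 3)
print(np.array_equal(new_map, expected))  # True
```

With more thresholds, np.digitize could produce the integer labels instead of a single comparison, and the palette would simply get more rows.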
