Error when using the numpy.where function

Time: 09-21

I want to use the numpy.where function to check whether each element of an array equals a given string (for example "coffee"). Where it does, I want to return one vector, and where it does not, a different vector.

However, I keep getting this error: ValueError: operands could not be broadcast together with shapes (4,) (1,3) (1,3)

Is there another way to do this without relying on for loops? The question explicitly says I should not use them.

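import numpy as np
lst_1 = np.array(["dog", "dog1", "dog2", "dog3"])
# Raises ValueError: the condition has shape (4,) but each choice has shape (1, 3)
a = np.where(lst_1 == "dog", [[1,0,0]], [[0,0,0]])
print(a)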
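For reference, not part of the original thread: np.where broadcasts the condition and both choices against each other. The condition lst_1 == "dog" has shape (4,) and each choice has shape (1,3); broadcasting aligns the trailing axes, and 4 against 3 cannot be reconciled. A minimal sketch of one fix adds a trailing axis to the condition so it has shape (4,1), which broadcasts with (1,3) to (4,3). The name mask is ours, and np.broadcast_shapes assumes NumPy 1.20 or newer.

```python
import numpy as np

lst_1 = np.array(["dog", "dog1", "dog2", "dog3"])

# The comparison gives shape (4,); adding a trailing axis makes it (4, 1)
mask = (lst_1 == "dog")[:, np.newaxis]

# (4, 1) broadcasts with the (3,)-shaped choices to a (4, 3) result
a = np.where(mask, [1, 0, 0], [0, 0, 0])
print(a)
# [[1 0 0]
#  [0 0 0]
#  [0 0 0]
#  [0 0 0]]

# The shape rule behind the fix (np.broadcast_shapes needs NumPy >= 1.20)
print(np.broadcast_shapes((4, 1), (1, 3)))  # (4, 3)
```

Each row of the condition now pairs with a whole row of the output, which is exactly the "one vector per element" behaviour the question asks for.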

Answer:

This can be done with fancy indexing, almost as a one-liner:

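import numpy as np
lst_1 = np.array(["dog", "dog1", "dog2", "dog3"])
out = np.array([[0,0,0], [1,0,0]])  # row 0: no match, row 1: match
idx = lst_1 == "dog"                 # boolean mask of shape (4,)
print(out[idx.astype(np.int32)])     # False -> row 0, True -> row 1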

Alternatively, avoiding the explicit cast (np.take accepts the boolean mask as the integer indices 0 and 1):

np.take([[0,0,0],[1,0,0]], lst_1 == "dog", axis=0)
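Both variants in this answer treat the boolean mask as a row index into a two-row lookup table: False selects row 0 ([0,0,0]) and True selects row 1 ([1,0,0]). A self-contained sketch that runs both and checks that they agree; the names table, via_index and via_take are ours, for illustration only.

```python
import numpy as np

lst_1 = np.array(["dog", "dog1", "dog2", "dog3"])
table = np.array([[0, 0, 0],   # row 0: used where the condition is False
                  [1, 0, 0]])  # row 1: used where the condition is True
mask = lst_1 == "dog"

# Integer fancy indexing: the mask is cast to 0/1 row numbers first
via_index = table[mask.astype(np.intp)]

# np.take casts the boolean indices to integers itself
via_take = np.take(table, mask, axis=0)

print(np.array_equal(via_index, via_take))  # True
print(via_take)
```

The cast matters for plain indexing: table[mask] would be interpreted as boolean masking rather than row selection, and fails here because the mask has 4 entries while table has only 2 rows.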

Answer:

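If you only need to avoid writing an explicit for statement, you can map a lambda function over the array. Note that map still loops over the elements in Python, and it returns a list rather than a NumPy array: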

import numpy as np
lst_1 = np.array(["dog", "dog1", "dog2", "dog3"])
# Apply the lambda to each element: [1,0,0] for "dog", [0,0,0] otherwise
a = list(map(lambda x: [1,0,0] if x == 'dog' else [0,0,0], lst_1))

print(a)

> [[1, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]
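Because map yields a plain Python list of lists, a NumPy array has to be built explicitly if one is needed afterwards. A minimal sketch, assuming a 2-D integer array is the desired result; the name rows is ours.

```python
import numpy as np

lst_1 = np.array(["dog", "dog1", "dog2", "dog3"])

# map applies the lambda element by element (a Python-level loop) and yields lists
rows = list(map(lambda x: [1, 0, 0] if x == "dog" else [0, 0, 0], lst_1))

# Stack the equal-length lists into a single 2-D integer array
a = np.array(rows)
print(a.shape)  # (4, 3)
print(a)
```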