Say I have a 2x3 ndarray:
[[0,1,1],
[1,1,1]]
I want to replace any row that has 0 at its first index with [0,0,0], giving:
[[0,0,0],
[1,1,1]]
Is it possible to do this with np.where? Here's my attempt:
import numpy as np
arr = np.array([[0,1,1],[1,1,1]])
replacement = np.full(arr.shape,[0,0,0])
new = np.where(arr[:,0]==0,replacement,arr)
I'm met with the following error at the last line:
ValueError: operands could not be broadcast together with shapes (2,) (2,3) (2,3)
The error makes sense, but I don't know how to fix the code to accomplish my goal. Any advice would be greatly appreciated!
Edit: I was trying to simplify a higher-dimensional case, but it turns out the solution might not generalize.
If I have this ndarray:
[[[0,1,1],[1,1,1],[1,1,1]],
[[1,1,1],[1,1,1],[1,1,1]],
[[1,1,1],[1,1,1],[1,1,1]]]
how can I replace the first triplet with [0,0,0]?
Answer:
Simple boolean indexing (with the assigned row broadcast over the selection) will do; here a is your arr:
a[a[:,0]==0] = [0,0,0]
output:
array([[0, 0, 0],
[1, 1, 1]])
explanation:
# get first column
a[:,0]
# array([0, 1])
# compare to 0 creating a boolean array
a[:,0]==0
# array([ True, False])
# select rows where the boolean is True
a[a[:,0]==0]
# array([[0, 1, 1]])
# replace those rows with new array
a[a[:,0]==0] = [0,0,0]
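The last step also works when several rows match, because NumPy broadcasts the length-3 list [0,0,0] across every selected row. A small sketch with a made-up array (the values are arbitrary, chosen only so that two rows start with 0); note that this modifies a in place:

```python
import numpy as np

# Hypothetical example array: rows 0 and 2 start with 0
a = np.array([[0, 1, 1],
              [1, 1, 1],
              [0, 5, 7]])

mask = a[:, 0] == 0      # array([ True, False,  True])
print(a[mask].shape)     # (2, 3): two rows are selected
a[mask] = [0, 0, 0]      # [0, 0, 0] is broadcast to shape (2, 3)
print(a)
# [[0 0 0]
#  [1 1 1]
#  [0 0 0]]
```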
using np.where (the single-argument form, which returns the indices of the matching rows)
this is less elegant in my opinion:
a[np.where(a[:,0]==0)[0]] = [0,0,0]
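If you want to keep the three-argument np.where from the question (which returns a new array instead of modifying arr in place), the fix is to give the condition a trailing axis. arr[:,0]==0 has shape (2,), and broadcasting aligns trailing dimensions, so that length 2 is matched against the 3 columns and fails; reshaped to (2, 1), the mask broadcasts to (2, 3) row by row. A minimal sketch, where a single replacement row stands in for the full-size replacement array:

```python
import numpy as np

arr = np.array([[0, 1, 1],
                [1, 1, 1]])

# Boolean mask over rows, shape (2,)
row_mask = arr[:, 0] == 0

# Replacement row, shape (3,); broadcasts to (2, 3) like np.full(arr.shape, [0,0,0])
replacement_row = np.array([0, 0, 0])

# row_mask[:, None] has shape (2, 1), which broadcasts against (3,) and (2, 3)
new = np.where(row_mask[:, None], replacement_row, arr)

print(new)
# [[0 0 0]
#  [1 1 1]]
```

The original replacement = np.full(arr.shape,[0,0,0]) also works in place of replacement_row; only the condition's shape needed changing.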
Edit: generalization
input:
a = np.arange(3**3).reshape((3,3,3))
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]])
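transformation (a[...,0] takes the first element of every triplet, however many leading dimensions there are):
a[a[...,0]==0] = [0,0,0]
output: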
array([[[ 0, 0, 0],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]])
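The three-argument np.where version generalizes the same way: a[...,0]==0 has the shape of a without its last axis, so appending an axis with [..., None] lets it broadcast against a. A sketch that returns a new array b and leaves a unchanged:

```python
import numpy as np

a = np.arange(3**3).reshape((3, 3, 3))

# One boolean per triplet (last axis); shape (3, 3)
mask = a[..., 0] == 0

# mask[..., None] has shape (3, 3, 1) and broadcasts against (3, 3, 3)
b = np.where(mask[..., None], [0, 0, 0], a)

print(b[0, 0])  # [0 0 0]
print(b[0, 1])  # [3 4 5]
```

Because of the ..., the same two lines work unchanged for any number of leading dimensions.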