Replace values greater/lower than x only in some rows on numpy array

Time:08-28

Let's say I have a numpy array:

arr = np.random.rand(10, 3)    
array([[0.51619267, 0.69536865, 0.87577509],
       [0.82922802, 0.19907414, 0.97310683],
       [0.0961705 , 0.16849081, 0.9426728 ],
       [0.84950943, 0.7767404 , 0.38282767],
       [0.1204213 , 0.67595169, 0.46029065],
       [0.98476311, 0.59316099, 0.0877238 ],
       [0.31916443, 0.5374729 , 0.18507312],
       [0.99078367, 0.77783068, 0.05689834],
       [0.03063616, 0.9887299 , 0.41034183],
       [0.43509505, 0.11150762, 0.27512664]])

In this array I want to replace values lower than 0.1 with some other value (let's say 99), but only in rows that contain at least one value larger than 0.98. I can find those rows like this:

selected_rows = np.any(arr > 0.98, axis=1)
array([False, False, False, False, False,  True, False,  True,  True,
       False])

But I'm not sure how to combine both conditions. I tried the following:

arr[selected_rows][arr[selected_rows] < 0.1] = 99

but it has no effect on the original arr. Any ideas?

CodePudding user response:

You can use broadcasting with an element-wise AND to combine the row selection with the 2D mask arr < 0.1:

arr[(arr<0.1) & selected_rows[:,None]] = 99

output (using NaN instead of 99 for clarity):

array([[0.51619267, 0.69536865, 0.87577509],
       [0.82922802, 0.19907414, 0.97310683],
       [0.0961705 , 0.16849081, 0.9426728 ],  # 0.096 remains as no value > 0.98
       [0.84950943, 0.7767404 , 0.38282767],
       [0.1204213 , 0.67595169, 0.46029065],
       [0.98476311, 0.59316099,        nan],
       [0.31916443, 0.5374729 , 0.18507312],
       [0.99078367, 0.77783068,        nan],
       [       nan, 0.9887299 , 0.41034183],
       [0.43509505, 0.11150762, 0.27512664]])
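
A minimal self-contained sketch of the same masking step, assuming a small hand-picked array (the values and the names `attempt`, `mask` and `result` are illustrative, not from the question). It also hints at why the chained indexing in the question has no effect: `arr[selected_rows]` is boolean (advanced) indexing, which returns a copy, so the assignment lands on a temporary array.

```python
import numpy as np

arr = np.array([[0.50, 0.70, 0.90],
                [0.05, 0.20, 0.97],    # value < 0.1, but none > 0.98
                [0.99, 0.60, 0.08],    # both conditions met
                [0.03, 0.985, 0.40]])  # both conditions met

selected_rows = np.any(arr > 0.98, axis=1)   # shape (4,)

# Chained indexing: the first [] makes a copy, so `attempt` stays unchanged.
attempt = arr.copy()
attempt[selected_rows][attempt[selected_rows] < 0.1] = 99
unchanged = np.array_equal(attempt, arr)

# selected_rows[:, None] has shape (4, 1) and broadcasts against (4, 3),
# giving one boolean mask that is used in a single indexing step.
mask = (arr < 0.1) & selected_rows[:, None]
result = arr.copy()
result[mask] = 99
print(result)
```

Because the mask is applied in one indexing operation, the assignment writes directly into `result` instead of into an intermediate copy.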

CodePudding user response:

My answer is not as succinct as the one from mozway (which is a nice approach, by the way), but it works for your problem. Here's the code:

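import numpy as np

def replace(a, val):
    # a is a single 1-D row; act only if it holds a value above 0.98
    if np.any(a > 0.98):
        indices = np.where(a < 0.1)[0]  # positions of values below 0.1
        if indices.size:
            # axis=None treats a as a flattened 1-D view
            np.put_along_axis(a, indices, val, axis=None)
    return a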


arr = [...]
np.apply_along_axis(replace, 1, arr, np.nan) 

The code above uses numpy.apply_along_axis to apply the given function to each 1-D slice along the specified axis (axis 1 here, so each row). np.nan is passed as the replacement value instead of 99, only to make the output easier to read. Extra positional and keyword arguments given after the array are forwarded to the function; see the documentation for details.
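
As a rough sketch of how arguments are forwarded, the example below passes one extra positional argument and one keyword argument through `np.apply_along_axis`. The helper `fill_low` and its `threshold` and `trigger` parameters are hypothetical names made up for this illustration, and unlike `replace` it returns a new row via `np.where` instead of writing into its input.

```python
import numpy as np

def fill_low(row, val, threshold=0.1, trigger=0.98):
    # Hypothetical helper: replace values below `threshold` with `val`,
    # but only when the row contains a value above `trigger`.
    if np.any(row > trigger):
        return np.where(row < threshold, val, row)
    return row

arr = np.array([[0.50, 0.70, 0.90],
                [0.99, 0.60, 0.08],   # qualifies: 0.08 is replaced
                [0.05, 0.20, 0.97]])  # no value > 0.98: left alone

# np.nan is forwarded as `val`, and threshold=0.1 as a keyword argument.
out = np.apply_along_axis(fill_low, 1, arr, np.nan, threshold=0.1)
print(out)
```

Everything after the array in the call is handed to the function on each row, so the thresholds can be changed without editing the helper.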

The replace function works on one row at a time. It uses numpy.put_along_axis, which takes the array, the indices at which to put values, the values themselves, and the axis.

Output:

array([[0.51619267, 0.69536865, 0.87577509],
       [0.82922802, 0.19907414, 0.97310683],
       [0.0961705 , 0.16849081, 0.9426728 ],
       [0.84950943, 0.7767404 , 0.38282767],
       [0.1204213 , 0.67595169, 0.46029065],
       [0.98476311, 0.59316099,        nan],
       [0.31916443, 0.5374729 , 0.18507312],
       [0.99078367, 0.77783068,        nan],
       [       nan, 0.9887299 , 0.41034183],
       [0.43509505, 0.11150762, 0.27512664]])