What is an efficient way to replace all values of a matrix except for those in rows or columns which contain a given value?

Time:10-15

I'm trying to replace all values of the input matrix X with np.nan except for the rows and columns which contain a value v:


import numpy as np
from numpy.typing import NDArray


def get_masked_array(X: NDArray[np.float64], v: float) -> NDArray[np.float64]:
    # TODO: keep the rows and columns of X that contain v, set everything else to np.nan
    ...


# Input float array
X = np.array([[ 1.,  2.,  2.,  3.,  3.],
              [ 1.,  2.,  2.,  4.,  4.],
              [ 5.,  5.,  6.,  6.,  6.],
              [ 7.,  8.,  9.,  9.,  9.],
              [10., 10., 10., 10., 10.]])

Expected results:

>>> get_masked_array(X, 2.)
array([[ 1.,  2.,  2.,  3.,  3.],
       [ 1.,  2.,  2.,  4.,  4.],
       [nan,  5.,  6., nan, nan],
       [nan,  8.,  9., nan, nan],
       [nan, 10., 10., nan, nan]])

>>> get_masked_array(X, 3.)
array([[ 1.,  2.,  2.,  3.,  3.],
       [nan, nan, nan,  4.,  4.],
       [nan, nan, nan,  6.,  6.],
       [nan, nan, nan,  9.,  9.],
       [nan, nan, nan, 10., 10.]])
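To pin down the spec before optimizing, here is a plain-loop reference sketch (the name `get_masked_array_loop` is hypothetical, not from the question). It keeps an element if its row or its column contains `v` (exact equality assumed) and fills everything else with NaN. It is slow on large arrays but may be handy as a correctness check for vectorized versions.

```python
import numpy as np


def get_masked_array_loop(X: np.ndarray, v: float) -> np.ndarray:
    """Hypothetical reference implementation using explicit Python loops."""
    n_rows, n_cols = X.shape
    # Flag each row / column that contains v at least once (exact equality)
    row_has_v = [any(X[i, j] == v for j in range(n_cols)) for i in range(n_rows)]
    col_has_v = [any(X[i, j] == v for i in range(n_rows)) for j in range(n_cols)]

    out = np.full(X.shape, np.nan)  # start from an all-NaN float array
    for i in range(n_rows):
        for j in range(n_cols):
            # Keep X[i, j] if its row or its column contains v
            if row_has_v[i] or col_has_v[j]:
                out[i, j] = X[i, j]
    return out


X = np.array([[ 1.,  2.,  2.,  3.,  3.],
              [ 1.,  2.,  2.,  4.,  4.],
              [ 5.,  5.,  6.,  6.,  6.],
              [ 7.,  8.,  9.,  9.,  9.],
              [10., 10., 10., 10., 10.]])
print(get_masked_array_loop(X, 3.))
```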

Answer:

Check the row and column conditions first, then combine them into a single boolean mask using NumPy broadcasting:

v = 2

eq_v = X == v                                     # True wherever X equals v
has_v = eq_v.any(0) | eq_v.any(1, keepdims=True)  # check if row or col has v: (5,) | (5, 1) -> (5, 5)
np.where(has_v, X, np.nan)                        # keep X where the mask is True, NaN elsewhere

array([[ 1.,  2.,  2.,  3.,  3.],
       [ 1.,  2.,  2.,  4.,  4.],
       [nan,  5.,  6., nan, nan],
       [nan,  8.,  9., nan, nan],
       [nan, 10., 10., nan, nan]])
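Filling in the question's `get_masked_array` stub, the same broadcasting approach can be sketched as follows, with the intermediate mask shapes spelled out. This is a minimal sketch assuming a 2-D input and exact equality against `v`.

```python
import numpy as np
from numpy.typing import NDArray


def get_masked_array(X: NDArray[np.float64], v: float) -> NDArray[np.float64]:
    """Keep the rows/columns of X that contain v; replace everything else with NaN."""
    eq_v = X == v                                 # boolean matrix, same shape as X
    col_has_v = eq_v.any(axis=0)                  # shape (n_cols,): column j contains v
    row_has_v = eq_v.any(axis=1, keepdims=True)  # shape (n_rows, 1): row i contains v
    # (n_cols,) | (n_rows, 1) broadcasts to a full (n_rows, n_cols) mask
    keep = col_has_v | row_has_v
    return np.where(keep, X, np.nan)


X = np.array([[ 1.,  2.,  2.,  3.,  3.],
              [ 1.,  2.,  2.,  4.,  4.],
              [ 5.,  5.,  6.,  6.,  6.],
              [ 7.,  8.,  9.,  9.,  9.],
              [10., 10., 10., 10., 10.]])
print(get_masked_array(X, 2.))
```

Because `X == v` is an exact float comparison, values produced by arithmetic (e.g. `0.1 + 0.2`) may not match `v`; swapping in `np.isclose(X, v)` is one way to tolerate rounding, at the cost of choosing a tolerance.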

Using the walrus operator (Python 3.8+), you can do the same thing in one line, just as efficiently:

np.where((eq_v := (X == v)).any(0) | eq_v.any(1, keepdims=True), X, np.nan)
array([[ 1.,  2.,  2.,  3.,  3.],
       [ 1.,  2.,  2.,  4.,  4.],
       [nan,  5.,  6., nan, nan],
       [nan,  8.,  9., nan, nan],
       [nan, 10., 10., nan, nan]])
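To confirm the one-liner matches the two-step version, a plain `==` comparison is not enough, because NaN never compares equal to itself. A sketch using `np.array_equal` with `equal_nan=True`, assuming NumPy 1.19 or newer (where `equal_nan` was added) and Python 3.8+ for the walrus operator:

```python
import numpy as np

X = np.array([[ 1.,  2.,  2.,  3.,  3.],
              [ 1.,  2.,  2.,  4.,  4.],
              [ 5.,  5.,  6.,  6.,  6.],
              [ 7.,  8.,  9.,  9.,  9.],
              [10., 10., 10., 10., 10.]])
v = 2.

# Two-step version
eq_v = X == v
two_step = np.where(eq_v.any(0) | eq_v.any(1, keepdims=True), X, np.nan)

# Walrus one-liner; note it also (re)binds eq_v in the enclosing scope
one_liner = np.where((eq_v := (X == v)).any(0) | eq_v.any(1, keepdims=True), X, np.nan)

print((two_step == one_liner).all())                        # False: NaN != NaN
print(np.array_equal(two_step, one_liner, equal_nan=True))  # True
```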

Answer:

This'll work, but I'm sure there's an easier way:

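v = 2
# Unique row and column indices of every element equal to v
rows, cols = map(np.unique, np.argwhere(X == v).T)
N = np.full(X.shape, np.nan)  # start from an all-NaN float array
N[rows, :] = X[rows, :]       # copy the rows that contain v
N[:, cols] = X[:, cols]       # copy the columns that contain v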

Output:

array([[ 1.,  2.,  2.,  3.,  3.],
       [ 1.,  2.,  2.,  4.,  4.],
       [nan,  5.,  6., nan, nan],
       [nan,  8.,  9., nan, nan],
       [nan, 10., 10., nan, nan]])

For v = 3:

array([[ 1.,  2.,  2.,  3.,  3.],
       [nan, nan, nan,  4.,  4.],
       [nan, nan, nan,  6.,  6.],
       [nan, nan, nan,  9.,  9.],
       [nan, nan, nan, 10., 10.]])
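The index-based approach can skip `np.argwhere` and `np.unique` by asking `any()` directly which rows and columns contain `v`; `np.flatnonzero` then turns those boolean flags into sorted, unique index arrays. A sketch of that variant (the function name `get_masked_array_idx` is hypothetical):

```python
import numpy as np


def get_masked_array_idx(X: np.ndarray, v: float) -> np.ndarray:
    """Hypothetical index-based variant: copy whole rows/columns into a NaN array."""
    eq_v = X == v
    rows = np.flatnonzero(eq_v.any(axis=1))  # indices of rows containing v
    cols = np.flatnonzero(eq_v.any(axis=0))  # indices of columns containing v
    out = np.full(X.shape, np.nan)           # start from an all-NaN float array
    out[rows, :] = X[rows, :]                # copy the selected rows
    out[:, cols] = X[:, cols]                # copy the selected columns
    return out


X = np.array([[ 1.,  2.,  2.,  3.,  3.],
              [ 1.,  2.,  2.,  4.,  4.],
              [ 5.,  5.,  6.,  6.,  6.],
              [ 7.,  8.,  9.,  9.,  9.],
              [10., 10., 10., 10., 10.]])
print(get_masked_array_idx(X, 3.))
```

This still does two copy passes, whereas the broadcasting approach builds one mask and makes a single `np.where` pass; for small matrices the difference is unlikely to matter.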