Setting some 2d array labels to zero in python

Time:04-28

My goal is to set some labels in a 2D array to zero without using a for loop. Is there a faster NumPy way to do this? Ideally I would write something like temp_arr[labeled_im not in labels] = 0, but that does not work the way I'd like it to.

labeled_array = np.array([[1,2,3],
                          [4,5,6],
                          [7,8,9]])

labels = [2,4,5,6,8]
temp_arr = np.zeros((labeled_array.shape)).astype(int)
for label in labels:
    temp_arr[labeled_array == label] = label

>> temp_arr
[[0 2 0]
 [4 5 6]
 [0 8 0]]

The for loop gets quite slow when there are many labels to iterate over, so I would like to improve the execution time with NumPy.

CodePudding user response:

You can use temp_arr = np.where(np.isin(labeled_array, labels), labeled_array, 0), with labels given as a list or an array (a set has to be converted first, e.g. with list(labels)). For such a small array, though, the difference does not seem to be significant.
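
One caveat that is easy to check: np.isin expects an array-like as its second argument, and a Python set is treated as a single object rather than as a collection of values. The following is a minimal sketch of that behaviour, reusing the question's data; the variable names label_set, mask_from_set and mask_from_list are only illustrative.

```python
import numpy as np

labeled_array = np.array([[1, 2, 3],
                          [4, 5, 6],
                          [7, 8, 9]])
label_set = {2, 4, 5, 6, 8}

# Passing the set directly: NumPy wraps it as one object,
# so no element compares equal and the mask is all False.
mask_from_set = np.isin(labeled_array, label_set)

# Converting to a list first gives the intended membership test.
mask_from_list = np.isin(labeled_array, list(label_set))

result = np.where(mask_from_list, labeled_array, 0)
print(mask_from_set.any())  # False
print(result)
```

With the list-based mask, result matches the temp_arr produced by the loop in the question.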

import numpy as np
import time

labeled_array = np.array([[1,2,3],
                          [4,5,6],
                          [7,8,9]])

labels = [2,4,5,6,8]

start = time.time()
temp_arr_0 = np.zeros((labeled_array.shape)).astype(int)
for label in labels:
    temp_arr_0[labeled_array == label] = label
end = time.time()

print(f"Loop takes {end - start}")

start = time.time()
temp_arr_1 = np.where(np.isin(labeled_array, labels), labeled_array, 0)
end = time.time()

print(f"np.where takes {end - start}")

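# np.isin treats a set as a single object and returns an all-False mask,
# so the set is converted to a list before the membership test.
labels  = {2,4,5,6,8}

start = time.time()
temp_arr_2 = np.where(np.isin(labeled_array, list(labels)), labeled_array, 0)
end = time.time()

print(f"np.where with set takes {end - start}")

Output (timings vary from run to run; the last one was measured with the set passed directly to np.isin, so it is not a like-for-like comparison):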

Loop takes 5.3882598876953125e-05
np.where takes 0.00010514259338378906
np.where with set takes 3.314018249511719e-05
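
Timings on a 3x3 array are dominated by call overhead, so they say little about larger inputs. As a rough sketch of how the comparison could be repeated, the block below uses timeit on a bigger array; the array size, the label range and the function names with_loop and with_isin are assumptions chosen for illustration, and the numbers it prints will differ between machines.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
# Assumed sizes for illustration: a 300x300 label image with labels 1..200
labeled_array = rng.integers(1, 201, size=(300, 300))
labels = list(range(2, 200, 2))  # keep the even labels only

def with_loop():
    # One full-array comparison per label, as in the question
    out = np.zeros_like(labeled_array)
    for label in labels:
        out[labeled_array == label] = label
    return out

def with_isin():
    # Single vectorized membership test
    return np.where(np.isin(labeled_array, labels), labeled_array, 0)

# Both approaches should produce the same array
assert np.array_equal(with_loop(), with_isin())

loop_time = min(timeit.repeat(with_loop, number=3, repeat=3))
isin_time = min(timeit.repeat(with_isin, number=3, repeat=3))
print(f"loop: {loop_time:.4f} s, np.isin: {isin_time:.4f} s")
```

Taking the minimum over several repeats is a common way to reduce noise from other processes.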

CodePudding user response:

If the values in labels are unique (and memory isn't a concern), here is another way to go, based on broadcasting.

As the very first step, we convert labels to an ndarray

labels = np.array(labels)

Then, we produce two broadcastable arrays from labeled_array and labels

labeled_row = labeled_array.ravel()[np.newaxis, :]
labels_col = labels[:, np.newaxis]

The code above produces, respectively, a row array of shape (1,9)

array([[1, 2, 3, 4, 5, 6, 7, 8, 9]])

and a column array of shape (5,1)

array([[2],
       [4],
       [5],
       [6],
       [8]])

Now the two shapes are broadcastable (see the NumPy broadcasting documentation), so we can perform an elementwise comparison:

mask = labeled_row == labels_col

which returns a (5,9)-shaped boolean mask

array([[False,  True, False, False, False, False, False, False, False],
       [False, False, False,  True, False, False, False, False, False],
       [False, False, False, False,  True, False, False, False, False],
       [False, False, False, False, False,  True, False, False, False],
       [False, False, False, False, False, False, False,  True, False]])

If the assumption above is fulfilled, each row contains as many True values as the number of times the corresponding label appears in your labeled_array. A row can also be all False, which happens when a label in labels never appears in your labeled_array.
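
To make the per-row counts concrete, here is a small sketch on a hypothetical input (not the array from the question) in which label 2 appears twice and label 9 never appears. Summing the mask along axis 1 gives the number of occurrences of each label.

```python
import numpy as np

# Hypothetical input: label 2 appears twice, label 9 never appears
labeled_array = np.array([[1, 2, 3],
                          [2, 5, 6]])
labels = np.array([2, 5, 9])

# Same broadcasting comparison as above: shape (3, 6)
mask = labeled_array.ravel()[np.newaxis, :] == labels[:, np.newaxis]

# Number of True values per row = occurrences of each label
counts = mask.sum(axis=1)
print(counts)  # [2 1 0]
```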

To find out which labels actually appeared in your labeled_array, you can use np.nonzero on the boolean mask

indices = np.nonzero(mask)

which returns a tuple containing the row and column indices of the non-zero (i.e. True) elements

(array([0, 1, 2, 3, 4], dtype=int64), array([1, 3, 4, 5, 7], dtype=int64))

By construction, the first element of the tuple above holds the positions in labels of the labels that actually appeared in your labeled_array, so

appeared_labels = labels[indices[0]]

(note that appeared_labels can contain the same label several times in a row if that label appears more than once in your labeled_array).
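
A short sketch of that repetition, on a hypothetical input where label 2 occurs twice; np.unique is one way to collapse the repeats if only the distinct labels are needed.

```python
import numpy as np

# Hypothetical input: label 2 appears twice, label 9 never appears
labeled_array = np.array([[1, 2, 3],
                          [2, 5, 6]])
labels = np.array([2, 5, 9])

mask = labeled_array.ravel()[np.newaxis, :] == labels[:, np.newaxis]
rows, cols = np.nonzero(mask)

# One entry per match, so a label found twice is listed twice
appeared_labels = labels[rows]
print(appeared_labels)             # [2 2 5]
print(np.unique(appeared_labels))  # [2 5]
```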

We can now build and fill the output array:

out = np.zeros(labeled_array.size, dtype=int)
out[indices[1]] = labels[indices[0]]

and bring it back to the original shape

out = out.reshape(*labeled_array.shape)
array([[0, 2, 0],
       [4, 5, 6],
       [0, 8, 0]])
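
The steps above can be gathered into a single function. The sketch below is a variation rather than the exact sequence shown: instead of np.nonzero it reduces the mask with any(axis=0), which marks every position that matches at least one label. The function name keep_labels_broadcast is hypothetical.

```python
import numpy as np

def keep_labels_broadcast(labeled_array, labels):
    """Return a copy of labeled_array with values not in labels set to 0."""
    labels = np.asarray(labels)
    flat = labeled_array.ravel()
    # Boolean mask of shape (n_labels, n_elements); memory grows with both
    mask = flat[np.newaxis, :] == labels[:, np.newaxis]
    # An element is kept if it matches any of the labels
    keep = mask.any(axis=0)
    return np.where(keep, flat, 0).reshape(labeled_array.shape)

labeled_array = np.array([[1, 2, 3],
                          [4, 5, 6],
                          [7, 8, 9]])
out = keep_labels_broadcast(labeled_array, [2, 4, 5, 6, 8])
print(out)
```

Because the intermediate mask has one row per label, this approach is only practical when the number of labels times the array size fits comfortably in memory.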