Is there a way to speed up looping over numpy.where?

Time:02-24

Imagine you have a segmentation map, where each object is identified by a unique index, e.g. looking similar to this:

[Image: example segmentation map]

For each object, I would like to save which pixels it covers, but so far I could only come up with a plain for loop. Unfortunately, for larger images with thousands of individual objects this turns out to be very slow, at least for my real data. Can I somehow speed things up?

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from skimage.draw import random_shapes


# please ignore that this does not always produce 20 objects each with a
# unique color. it is simply a quick way to produce data that is similar to
# my problem that can also be visualized.
segmap, labels = random_shapes(
    (100, 100), 20, min_size=6, max_size=20, multichannel=False,
    intensity_range=(0, 20), num_trials=100,
)
segmap = np.ma.masked_where(segmap == 255, segmap)

object_idxs = np.unique(segmap)[:-1]
objects = np.empty(object_idxs.size, dtype=[('idx', 'i4'), ('pixels', 'O')])

# important bit here:

# this I can do in parallel
objects['idx'] = object_idxs
# but this I cannot. and it takes forever.
for i, idx in enumerate(object_idxs):
    # compare against the label value itself, not the loop counter
    objects[i]['pixels'] = np.where(segmap == idx)

# just plotting here
fig, ax = plt.subplots(constrained_layout=True)
image = ax.imshow(
    segmap, cmap='tab20', norm=mpl.colors.Normalize(vmin=0, vmax=20)
)
fig.colorbar(image)
fig.show()

CodePudding user response:

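Calling np.where inside a loop is algorithmically inefficient: every iteration scans the whole image, so the total time complexity is O(s n m), where s = object_idxs.size and n, m = segmap.shape. The same result can be computed in O(n m) (a sort-based approach adds a logarithmic factor for the sort, which is still much cheaper when s is large).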
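To make the complexity argument concrete, here is a minimal sketch of a single pass that groups pixel coordinates by label. The array `toy_segmap` and the name `pixels_by_label` are made up for illustration, and 255 is assumed to mark the background as in the question. A pure-Python loop like this is slow in practice; it only shows that each pixel needs to be visited once, regardless of how many objects there are.

```python
import numpy as np
from collections import defaultdict

# Hypothetical toy label map; 255 marks background as in the question.
toy_segmap = np.array([
    [1, 1, 255],
    [255, 2, 2],
])

# One pass over the image: each pixel is visited exactly once, so the work is
# proportional to n * m, independent of the number of objects.
pixels_by_label = defaultdict(list)
for (row, col), label in np.ndenumerate(toy_segmap):
    if label != 255:
        pixels_by_label[int(label)].append((row, col))
```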

One solution using NumPy is to first select the locations of all object pixels, then sort them by their associated object label in segmap, and finally split the sorted list at the boundaries between objects, using the per-object pixel counts. Here is the code:

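# segmap is a masked array in the question; fill the masked background with 255 again
# (np.max of a masked array would return the largest object label, not the background)
background = 255
labelMap = np.ma.filled(segmap, background)
mask = labelMap != background
pixelLabels = labelMap[mask]
uniqueObjects, counts = np.unique(pixelLabels, return_counts=True)
# stable sort keeps the pixels of each object in row-major order
ordering = np.argsort(pixelLabels, kind='stable')
i, j = np.where(mask)
indices = np.vstack([i[ordering], j[ordering]])
# cut the sorted index list at the cumulative pixel counts: one chunk per object
indicesPerObject = np.split(indices, counts.cumsum()[:-1], axis=1)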

objects = np.empty(object_idxs.size, dtype=[('idx', 'i4'), ('pixels', 'O')])
objects['idx'] = object_idxs
for i in range(object_idxs.size):
    # You could use `tuple(...)` to get the exact same type as the initial code here
    objects[i]['pixels'] = indicesPerObject[i]
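As a sanity check, the sort-and-split idea can be tried on a tiny hand-written label map and compared against the plain np.where loop. This is only a sketch: the helper name `pixels_per_object`, the toy array, and the background value 9 are assumptions chosen for the example, not part of the original code.

```python
import numpy as np

# Hypothetical toy label map: 9 marks background, 0..2 are object labels.
BACKGROUND = 9
segmap = np.array([
    [0, 0, 9, 1],
    [9, 2, 2, 1],
    [0, 9, 2, 9],
])

def pixels_per_object(label_map, background):
    """Return (labels, list of (2, k) index arrays) using one sort and one split."""
    mask = label_map != background
    pixel_labels = label_map[mask]
    labels, counts = np.unique(pixel_labels, return_counts=True)
    # A stable sort keeps the pixels of one object in row-major order.
    ordering = np.argsort(pixel_labels, kind="stable")
    rows, cols = np.nonzero(mask)
    indices = np.vstack([rows[ordering], cols[ordering]])
    return labels, np.split(indices, counts.cumsum()[:-1], axis=1)

labels, groups = pixels_per_object(segmap, BACKGROUND)

# Cross-check against the straightforward np.where loop.
for label, group in zip(labels, groups):
    expected = np.vstack(np.where(segmap == label))
    assert np.array_equal(group, expected)
```

Each entry of `groups` is a 2-row array (row indices, then column indices); wrapping it in `tuple(...)` gives a value that can be used directly for indexing, like the output of np.where.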

CodePudding user response:

If I understand the question correctly, you would like to see where any object is located, right? If we start with one matrix (that is, all shapes are in one array, where empty spaces are zeros, object one consists of 1s, object two of 2s, etc.), then you can create a mask showing which pixels (or values in the matrix) are non-zero like this:

my_array != 0

Does that answer your question?
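For reference, a small sketch of what that mask gives on a toy array (the array contents are invented, and 0 is assumed to mean empty space as described in this answer):

```python
import numpy as np

# Hypothetical label map where 0 is empty space, as this answer assumes.
my_array = np.array([
    [0, 1, 1],
    [2, 0, 1],
])

foreground = my_array != 0            # boolean mask of all non-empty pixels
rows, cols = np.nonzero(foreground)   # coordinates of those pixels
```

Note that this single mask covers all objects at once; the label values in `my_array` are still needed to tell which pixel belongs to which object.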

Edit for clarification
