How to use argmax to return indices into multidimensional ndarray that cannot be re-shaped into a ma

Time:08-30

Given a batch of two 9x9 images with 2 channels each, built like this:

import numpy as np

img1 = np.arange(162).reshape(9,9,2).copy()
img2 = img1 * 2
batch = np.array([img1, img2])

I need to split each image into 3x3x2 regions (stride 3), then locate and replace the max element of each region, separately for every image and channel. For the example above, these elements are at:

  • (:, 2, 2, :)
  • (:, 2, 5, :)
  • (:, 2, 8, :)
  • (:, 5, 2, :)
  • (:, 5, 5, :)
  • (:, 5, 8, :)
  • (:, 8, 2, :)
  • (:, 8, 5, :)
  • (:, 8, 8, :)
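As a quick sanity check of the positions listed above, the sketch below (assuming NumPy; splitting the batch into 3x3 blocks with a 6D reshape is an illustrative choice, not part of the question) compares each block's maximum with the values found at rows and columns 2, 5 and 8:

```python
import numpy as np

# Rebuild the example batch from the question: two 9x9 images with 2 channels.
img1 = np.arange(162).reshape(9, 9, 2)
img2 = img1 * 2
batch = np.array([img1, img2])            # shape (2, 9, 9, 2)

region_size = 3
m, h, w, c = batch.shape

# Split rows and columns into (block, offset-within-block) axes.
# This is a view because batch is C-contiguous.
blocks = batch.reshape(m, h // region_size, region_size,
                       w // region_size, region_size, c)

# Maximum of every 3x3 region, per image and per channel: shape (2, 3, 3, 2).
region_max = blocks.max(axis=(2, 4))

# Values sitting at rows/cols 2, 5, 8 (the positions listed above).
corner_values = batch[:, 2::3, 2::3, :]

print(np.array_equal(region_max, corner_values))  # True
```

This holds here because the example values grow strictly with row and column index, so every region's maximum sits in its bottom-right corner.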

So far my solution is this:

batch_size, _, _, channels = batch.shape
region_size = 3

# For the (0, 0) region
region_slice = (slice(batch_size), slice(region_size), slice(region_size), slice(channels))
region = batch[region_slice]
new_values = np.arange(batch_size * channels)

# Flatten each channel of an image
region_3d = region.reshape(batch_size, region_size ** 2, channels)

# Find indices of max element for each channel
region_3d_argmax = region_3d.argmax(axis=1)

# Manually unravel indices
region_argmax = (
    np.repeat(np.arange(batch_size), channels),
    *np.unravel_index(region_3d_argmax.ravel(), (region_size, region_size)),
    np.tile(np.arange(channels), batch_size)
)

batch[region_slice][region_argmax] = new_values

There are two problems with this code:

  • Reshaping region may return a copy instead of a view
  • The indices have to be unraveled manually
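The first problem can be confirmed with np.shares_memory. A minimal sketch, assuming the example batch from the question: slicing out the (0, 0) region gives a view, but merging its two spatial axes with reshape produces a copy, so anything written into region_3d never reaches batch.

```python
import numpy as np

img1 = np.arange(162).reshape(9, 9, 2)
batch = np.array([img1, img1 * 2])        # C-contiguous, shape (2, 9, 9, 2)

# The (0, 0) region: basic slicing, so this is a view into batch.
region = batch[:, :3, :3, :]
region_is_view = np.shares_memory(region, batch)

# The two 3-long spatial axes of this strided view cannot be merged by
# changing strides alone, so reshape silently returns a copy.
region_3d = region.reshape(2, 9, 2)
region_3d_is_view = np.shares_memory(region_3d, batch)

print(region_is_view, region_3d_is_view)  # True False

# Writing into the copy therefore leaves batch untouched.
region_3d[:] = -1
print(bool((batch == -1).any()))          # False
```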

What is the better way to perform this operation?

Answer:

With merging axes

A better way, in terms of memory and hence performance, is to use advanced indexing with an appropriate indexing tuple into batch_3d (a 3D array whose axis 1 holds the merged region axes):

m,n = idx.shape
indexer = np.arange(m)[:,None],idx,np.arange(n)

The replacement is then done by assigning through that tuple, with the replacement array reshaped to the indexed shape (m,n) if it is not already in that shape. Writing batch_3d[indexer].flat = ... instead would not work: advanced indexing returns a copy, so the values would land in a temporary array rather than in batch_3d.

batch_3d[indexer] = replacement_array.reshape(m,n)
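A small, self-contained run of the merged-axes approach follows. The random batch_3d (shape (m, k, n), with k merged region positions) and the negative replacement values are hypothetical stand-ins; the block also shows that assigning through .flat on the indexed result leaves batch_3d unchanged.

```python
import numpy as np

# Hypothetical stand-in for batch_3d: m items, k merged region positions, n channels.
rng = np.random.default_rng(0)
m, k, n = 2, 9, 2
batch_3d = rng.permutation(m * k * n).reshape(m, k, n)
# Hypothetical replacement values, one per (item, channel) pair.
replacement_array = -np.arange(1, m * n + 1)

idx = batch_3d.argmax(axis=1)                        # shape (m, n)
indexer = np.arange(m)[:, None], idx, np.arange(n)   # broadcasts to (m, n)

# Pitfall: advanced indexing returns a copy, so .flat writes into a temporary.
before = batch_3d.copy()
batch_3d[indexer].flat = replacement_array
flat_was_noop = np.array_equal(batch_3d, before)
print(flat_was_noop)                                 # True: nothing changed

# Assigning through the indexer writes into batch_3d itself.
batch_3d[indexer] = replacement_array.reshape(m, n)
print(np.count_nonzero(batch_3d != before))          # 4
```

Only the m*n argmax positions change; every other element keeps its original value.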

We can also use the built-in np.put_along_axis, with p as the replacement array:

np.put_along_axis(batch_3d,idx[:,None,:],p.reshape(m,1,n),axis=1)
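To check that np.put_along_axis performs the same replacement as the explicit indexing tuple, here is a sketch on hypothetical random data; both variants are applied to separate copies and then compared:

```python
import numpy as np

rng = np.random.default_rng(1)
m, k, n = 2, 9, 2
batch_3d = rng.permutation(m * k * n).reshape(m, k, n)   # hypothetical data
p = -np.arange(1, m * n + 1)                              # hypothetical replacement values

idx = batch_3d.argmax(axis=1)                             # shape (m, n)

# Variant 1: explicit advanced-indexing tuple.
a = batch_3d.copy()
a[np.arange(m)[:, None], idx, np.arange(n)] = p.reshape(m, n)

# Variant 2: put_along_axis; indices and values keep axis 1 as length 1
# so they have the same number of dimensions as batch_3d.
b = batch_3d.copy()
np.put_along_axis(b, idx[:, None, :], p.reshape(m, 1, n), axis=1)

print(np.array_equal(a, b))   # True
```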

Note: The idx used in this post is generated by idx = batch_3d.argmax(axis=1). Because it indexes the merged axis directly, the manual unraveling step is not needed.


Without merging axes

Here we define helper functions that perform the argmax-based replacements along multiple axes without merging axes that are not adjacent, since merging non-adjacent axes would force a copy.
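The copy-forcing behaviour can be observed directly. In this sketch (assuming a C-contiguous batch of the question's shape), splitting rows and columns into (block, offset) axes is a view, whereas gathering the two offset axes 2 and 4 into one axis via transpose and reshape yields a copy:

```python
import numpy as np

batch = np.arange(2 * 9 * 9 * 2).reshape(2, 9, 9, 2)   # C-contiguous example

# Splitting each spatial axis into (block, offset) only changes strides: a view.
batch6D = batch.reshape(2, 3, 3, 3, 3, 2)
print(np.shares_memory(batch6D, batch))   # True

# Axes 2 and 4 (row and column offsets inside a region) are not adjacent.
# Bringing them together needs a transpose, and flattening the pair afterwards
# cannot be expressed with strides, so reshape has to copy.
merged = batch6D.transpose(0, 1, 3, 2, 4, 5).reshape(2, 3, 3, 9, 2)
print(np.shares_memory(merged, batch))    # False
```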

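import numpy as np

def indexer_skip_one_axis(a, axis):
    # Indexing tuple that re-inserts a length-1 axis at position `axis`
    # (like keepdims=True) after a reduction over that axis
    return tuple(slice(None) if i!=axis else None for i in range(a.ndim))

def argmax_along_axes(a, axis):
    # a is input array
    # axis is a tuple of two axes along which argmax indices are to be computed
    # Argmax along the first axis, plus the max values found there
    argmax1 = (a.argmax(axis[0]))[indexer_skip_one_axis(a,axis[0])]
    val_argmax1 = np.take_along_axis(a,argmax1,axis=axis[0])
    # Argmax of those maxima along the second axis, then the matching first-axis index
    argmax2 = (val_argmax1.argmax(axis[1]))[indexer_skip_one_axis(a,axis[1])]
    val_argmax2 = np.take_along_axis(argmax1,argmax2,axis=axis[1])
    # Open mesh over all axes, with the two reduced axes replaced by the argmax indices
    r = list(np.ix_(*[np.arange(i) for i in a.shape]))
    r[axis[0]] = val_argmax2
    r[axis[1]] = argmax2
    return tuple(r)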
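On NumPy 1.22 and later, argmax accepts keepdims=True, which does the same job as indexer_skip_one_axis. The following is a hypothetical variant of argmax_along_axes written that way (a sketch, not the answer's original code), checked against max over the same axes on random data:

```python
import numpy as np

def argmax_along_axes_keepdims(a, axis):
    """Hypothetical keepdims-based variant of argmax_along_axes.
    Requires NumPy >= 1.22; `axis` is a pair of axes."""
    ax0, ax1 = axis
    # Position of the max along ax0 for every other index (ax0 kept as length 1).
    arg0 = a.argmax(axis=ax0, keepdims=True)
    max0 = np.take_along_axis(a, arg0, axis=ax0)
    # Position along ax1 of the largest of those maxima.
    arg1 = max0.argmax(axis=ax1, keepdims=True)
    # The ax0 position that goes with that ax1 position.
    arg0_at_arg1 = np.take_along_axis(arg0, arg1, axis=ax1)
    # Open mesh for every axis, then overwrite the two reduced axes.
    r = list(np.ix_(*[np.arange(size) for size in a.shape]))
    r[ax0] = arg0_at_arg1
    r[ax1] = arg1
    return tuple(r)

rng = np.random.default_rng(2)
a = rng.integers(0, 100, size=(2, 3, 3, 3, 3, 2))
idx = argmax_along_axes_keepdims(a, axis=(2, 4))
print(np.array_equal(a[idx], a.max(axis=(2, 4), keepdims=True)))  # True
```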

Hence, to solve our case and do all the replacements at once:

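region_size = 3
m,n,r,s = batch.shape
# Split rows and columns into (block, offset) axes; a view, since batch is contiguous
batch6D = batch.reshape(m,n//region_size,region_size,r//region_size,region_size,s)
# One replacement value per (image, channel), broadcast over all regions
batch6D[argmax_along_axes(batch6D, axis=(2,4))] = new_values.reshape(m,1,1,1,1,s)
# batch6D is a view, so batch is already modified; out is the same data in 4D
out = batch6D.reshape(m,n,r,s)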
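An end-to-end check of the final solution on the question's example, with the answer's helpers copied in so the block runs on its own. new_values follows the question's definition (one value per image and channel); after the replacement, exactly the positions listed in the question should hold the new values and everything else should be unchanged:

```python
import numpy as np

# Helpers reproduced from the answer above.
def indexer_skip_one_axis(a, axis):
    return tuple(slice(None) if i != axis else None for i in range(a.ndim))

def argmax_along_axes(a, axis):
    argmax1 = a.argmax(axis[0])[indexer_skip_one_axis(a, axis[0])]
    val_argmax1 = np.take_along_axis(a, argmax1, axis=axis[0])
    argmax2 = val_argmax1.argmax(axis[1])[indexer_skip_one_axis(a, axis[1])]
    val_argmax2 = np.take_along_axis(argmax1, argmax2, axis=axis[1])
    r = list(np.ix_(*[np.arange(i) for i in a.shape]))
    r[axis[0]] = val_argmax2
    r[axis[1]] = argmax2
    return tuple(r)

# Example batch from the question.
img1 = np.arange(162).reshape(9, 9, 2)
batch = np.array([img1, img1 * 2])
original = batch.copy()

region_size = 3
m, n, r, s = batch.shape
new_values = np.arange(m * s)               # as in the question

batch6D = batch.reshape(m, n // region_size, region_size,
                        r // region_size, region_size, s)   # view of batch
batch6D[argmax_along_axes(batch6D, axis=(2, 4))] = new_values.reshape(m, 1, 1, 1, 1, s)

# Expected result: only rows/cols 2, 5, 8 replaced, per image and channel.
expected = original.copy()
expected[:, 2::3, 2::3, :] = new_values.reshape(m, 1, 1, s)
print(np.array_equal(batch, expected))      # True
```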