Given an array of two 9x9 images with 2 channels each, shaped like this:
import numpy as np
img1 = np.arange(162).reshape(9,9,2).copy()
img2 = img1 * 2
batch = np.array([img1, img2])
I need to slice each image into 3x3x2 (stride=3) regions and then locate and replace max elements of each slice. For the example above these elements are:
(:, 2, 2, :)
(:, 2, 5, :)
(:, 2, 8, :)
(:, 5, 2, :)
(:, 5, 5, :)
(:, 5, 8, :)
(:, 8, 2, :)
(:, 8, 5, :)
(:, 8, 8, :)
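As a sanity check, the positions listed above can be reproduced with a short loop. This is only a sketch for channel 0 of the first image (the other image and channel give the same positions for this example data); `positions` is a name introduced here for illustration.

```python
import numpy as np

img1 = np.arange(162).reshape(9, 9, 2)
img2 = img1 * 2
batch = np.array([img1, img2])

region_size = 3
positions = []  # hypothetical name: (row, col) of the max in each 3x3 region
for row in range(0, 9, region_size):
    for col in range(0, 9, region_size):
        # One 3x3 region of channel 0 in the first image
        block = batch[0, row:row + region_size, col:col + region_size, 0]
        r, c = np.unravel_index(block.argmax(), block.shape)
        positions.append((row + int(r), col + int(c)))

print(positions)
```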
So far my solution is this:
batch_size, _, _, channels = batch.shape
region_size = 3
# For the (0, 0) region
region_slice = (slice(batch_size), slice(region_size), slice(region_size), slice(channels))
region = batch[region_slice]
new_values = np.arange(batch_size * channels)
# Flatten each channel of an image
region_3d = region.reshape(batch_size, region_size ** 2, channels)
# Find indices of max element for each channel
region_3d_argmax = region_3d.argmax(axis=1)
# Manually unravel indices
region_argmax = (
    np.repeat(np.arange(batch_size), channels),
    *np.unravel_index(region_3d_argmax.ravel(), (region_size, region_size)),
    np.tile(np.arange(channels), batch_size)
)
batch[region_slice][region_argmax] = new_values
There are two problems with this code:
- Reshaping region may return a copy instead of a view
- Manual unraveling
What is the better way to perform this operation?
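The first concern can be checked directly with np.shares_memory. A minimal sketch, assuming the same batch shape as above; the variable names `is_view_region` and `is_view_reshaped` are introduced here for illustration only.

```python
import numpy as np

batch = np.arange(2 * 9 * 9 * 2).reshape(2, 9, 9, 2)

# Basic slicing returns a view into batch.
region = batch[:, :3, :3, :]

# Merging the two spatial axes of the sliced region cannot be expressed
# with strides, so reshape has to copy the data.
region_3d = region.reshape(2, 9, 2)

is_view_region = np.shares_memory(batch, region)
is_view_reshaped = np.shares_memory(batch, region_3d)
print(is_view_region, is_view_reshaped)
```

Because the reshaped region does not share memory with batch, writing into it leaves batch unchanged unless the result is assigned back.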
CodePudding user response:
With merging axes
A better way, in terms of memory and hence performance, is to use advanced indexing with an appropriate indexing tuple:
m, n = idx.shape
indexer = np.arange(m)[:, None], idx, np.arange(n)
batch_3d[indexer] = replacement_array.reshape(m, n)
The replacement array is reshaped to the indexed shape (m, n); skip the reshape if it already has that shape. Assign through batch_3d[indexer] directly rather than through batch_3d[indexer].flat, because advanced indexing returns a copy and the .flat assignment would only modify that copy.
We can also use the built-in np.put_along_axis, with p as the replacement array:
np.put_along_axis(batch_3d,idx[:,None,:],p.reshape(m,1,n),axis=1)
Note: The idx used in this post is the one generated from idx = batch_3d.argmax(axis=1), where batch_3d is the 3D array with the spatial axes merged (region_3d in the question). Hence we skip the manual unraveling step.
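Putting the pieces together for the (0, 0) region, a minimal end-to-end sketch might look like this. It assumes the setup from the question; since reshaping the sliced region copies, the modified data is written back explicitly at the end, which is an addition of this sketch.

```python
import numpy as np

img1 = np.arange(162).reshape(9, 9, 2)
batch = np.array([img1, img1 * 2])
batch_size, _, _, channels = batch.shape
region_size = 3

region_slice = (slice(None), slice(region_size), slice(region_size), slice(None))

# Merge the two spatial axes; this reshape copies the region.
batch_3d = batch[region_slice].reshape(batch_size, region_size ** 2, channels)

idx = batch_3d.argmax(axis=1)   # shape (batch_size, channels)
m, n = idx.shape
p = np.arange(m * n)            # one replacement value per (image, channel)

# Write p at the argmax position of every (image, channel) pair.
np.put_along_axis(batch_3d, idx[:, None, :], p.reshape(m, 1, n), axis=1)

# Copy the modified region back into the original array.
batch[region_slice] = batch_3d.reshape(batch_size, region_size, region_size, channels)
```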
Without merging axes
We can define helper functions that perform the argmax-based replacement along multiple axes without merging non-adjacent axes, since merging them would force a copy.
def indexer_skip_one_axis(a, axis):
    return tuple(slice(None) if i!=axis else None for i in range(a.ndim))
def argmax_along_axes(a, axis):
    # a is input array
    # axis is tuple of axes along which argmax indices are to be computed
    argmax1 = (a.argmax(axis[0]))[indexer_skip_one_axis(a,axis[0])]
    val_argmax1 = np.take_along_axis(a,argmax1,axis=axis[0])
    argmax2 = (val_argmax1.argmax(axis[1]))[indexer_skip_one_axis(a,axis[1])]
    val_argmax2 = np.take_along_axis(argmax1,argmax2,axis=axis[1])
    r = list(np.ix_(*[np.arange(i) for i in a.shape]))
    r[axis[0]] = val_argmax2
    r[axis[1]] = argmax2
    return tuple(r)
Hence, all the replacements for our case can be done with:
m,n,r,s = batch.shape
batch6D = batch.reshape(m,n//3,3,r//3,3,s)
batch6D[argmax_along_axes(batch6D, axis=(2,4))] = new_values.reshape(2,1,1,1,1,2)
out = batch6D.reshape(m,n,r,s)
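To validate a vectorized result, it can help to compare it against a plain-loop reference. The sketch below is one such reference, not part of the approach above; the function name `replace_region_max_loop` and the layout of `new_values` as one value per (image, channel) are assumptions made for illustration. If a region contains tied maxima, the loop picks the first one in row-major order, which may differ from a vectorized version.

```python
import numpy as np

def replace_region_max_loop(batch, new_values, region_size=3):
    """Hypothetical reference: replace the max of every region, per image and channel."""
    out = batch.copy()
    b, h, w, c = out.shape
    for i in range(b):
        for k in range(c):
            for row in range(0, h, region_size):
                for col in range(0, w, region_size):
                    # Basic slicing gives a view, so the write below lands in out.
                    block = out[i, row:row + region_size, col:col + region_size, k]
                    r_, c_ = np.unravel_index(block.argmax(), block.shape)
                    block[r_, c_] = new_values[i, k]
    return out

img1 = np.arange(162).reshape(9, 9, 2)
batch = np.array([img1, img1 * 2])
new_values = np.arange(4).reshape(2, 2)   # one value per (image, channel)

expected = replace_region_max_loop(batch, new_values)
```

The result can then be compared with the vectorized output using np.array_equal.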