python numpy maxpool: given an array and indices from argmax, returns max values

Time:10-04

Suppose I have an array called view:

array([[[[ 7,  9],
         [10, 11]],

        [[19, 18],
         [20, 16]]],


       [[[24,  5],
         [ 6, 10]],

        [[18, 11],
         [45, 12]]]])

As you may know from max pooling, this is a windowed view of the original input with a 2x2 kernel, so each innermost 2x2 block is one window:

[[ 7,  9],    [[19, 18],
 [10, 11]],    [20, 16]], ...

The goal is to find both the max values and their indices. However, argmax only works along a single axis, so I need to flatten each kernel in view, i.e. flatten = view.reshape(2, 2, 4):

array([[[ 7,  9, 10, 11], [19, 18, 20, 16]],

       [[24,  5,  6, 10], [18, 11, 45, 12]]])

Now, with the help I got from my previous question, I can find the indices of the max values using inds = flatten.argmax(-1):

array([[3, 2],
       [0, 2]])

and values of max:

i, j = np.indices(flatten.shape[:-1])
flatten[i, j, inds]

>>> array([[11, 20],
           [24, 45]])

The problem
The problem arises when I flatten the view array. view is a strided view of the original array, i.e. view = as_strided(original, newshape, newstrides), so view and original share the same data. However, reshape returns a copy here, so changes made to the flattened array are not reflected in original. This is a problem during backpropagation.
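A small check makes the copy visible. The sketch below reuses the original array and the as_strided call from the reproducible example in this question (the byte strides assume a C-contiguous float64 array), and uses np.shares_memory to compare the buffers; the name flat is introduced here for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

original = np.array([[[7, 9, 19, 18], [10, 11, 20, 16]],
                     [[24, 5, 18, 11], [6, 10, 45, 12]]], dtype=np.float64)
# Strides are in bytes and assume float64 (8 bytes per element).
view = as_strided(original, shape=(2, 1, 2, 2, 2), strides=(64, 64, 16, 32, 8))

# The last two axes of view cannot be merged without moving data,
# so reshape has to allocate a new array.
flat = view.reshape(2, 2, 4)

print(np.shares_memory(view, original))  # True: view aliases original
print(np.shares_memory(flat, original))  # False: flat is a copy

flat[0, 0, 3] = 1000     # overwrite the max of the first kernel in the copy
print(original[0, 1, 1])  # 11.0: original is unchanged
```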

My question
Given the array view and the indices inds, I'd like to change the max values in view to 1000 without using reshape or any other operation that breaks the 'bond' between view and original. Thanks for any help!!!

reproducible example

import numpy as np
from numpy.lib.stride_tricks import as_strided

original=np.array([[[7,9,19,18],[10,11,20,16]],[[24,5,18,11],[6,10,45,12]]],dtype=np.float64)
view=as_strided(original, shape=(2,1,2,2,2),strides=(64,32*2,8*2,32,8))

I'd like to change the max value of each kernel in view to, say, 1000, in a way that is reflected in original, i.e. if I run view[0,0,0,0,0]=1000, then the first element of both view and original becomes 1000.

Answer:

How about this:

import numpy as np
view = np.array(
    [[[[ 7,  9],
       [10, 11]],
      [[19, 18],
       [20, 16]]],
     [[[24,  5],
       [ 6, 10]],
      [[18, 11],
       [45, 12]]]]
)
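# Getting the indices of the max values
max0 = view.max(-2)                      # column-wise max within each 2x2 kernel
idx2 = view.argmax(-2)                   # row index of each column-wise max
idx2 = idx2.reshape(-1, idx2.shape[-1])  # one row per kernel
max1 = max0.max(-1)                      # max value of each kernel
idx3 = max0.argmax(-1).flatten()         # column index of each kernel's max
idx2 = idx2[np.arange(idx3.size), idx3]  # row index of each kernel's max

# Indices over the first two dimensions, one entry per kernel
idx0 = np.arange(view.shape[0]).repeat(view.shape[1])
idx1 = np.arange(view.shape[1]).reshape(1, -1).repeat(view.shape[0], 0).flatten()

# Replacing the maximal values with 1000
view[idx0, idx1, idx2, idx3] = 1000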
print(f'view = \n{view}')

output:

view = 
[[[[   7    9]
   [  10 1000]]

  [[  19   18]
   [1000   16]]]


 [[[1000    5]
   [   6   10]]

  [[  18   11]
   [1000   12]]]]

Basically, idx0 and idx1 enumerate the matrices along the first two dimensions, while idx2 and idx3 are the row and column of the maximal value within each of those matrices.
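The same idea can be tried on the 5-D as_strided view from the question. The sketch below is a variant, not the code above: it lets reshape make a throwaway copy that is only read to locate the maxima, converts the flat positions to (row, col) with np.unravel_index, and then assigns through view with advanced indexing so the write lands in original. The names lead_shape, flat_inds, rows, cols and lead_idx are introduced here for illustration, and non-overlapping windows are assumed.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

original = np.array([[[7, 9, 19, 18], [10, 11, 20, 16]],
                     [[24, 5, 18, 11], [6, 10, 45, 12]]], dtype=np.float64)
# Strides are in bytes and assume a C-contiguous float64 array.
view = as_strided(original, shape=(2, 1, 2, 2, 2), strides=(64, 64, 16, 32, 8))

lead_shape = view.shape[:-2]   # axes that enumerate the kernels: (2, 1, 2)
kh, kw = view.shape[-2:]       # kernel height and width

# reshape copies here, which is fine: the copy is only read, never written.
flat_inds = view.reshape(*lead_shape, kh * kw).argmax(-1)
rows, cols = np.unravel_index(flat_inds, (kh, kw))

# One index grid per leading axis, each with shape lead_shape.
lead_idx = np.indices(lead_shape)

# Advanced-index assignment on view writes into the memory shared with original.
view[(*lead_idx, rows, cols)] = 1000

print(original)
```

After this runs, the four kernel maxima (11, 20, 24 and 45) in original are replaced by 1000, while view itself was never reshaped.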
