Vectorized sum of different numbers of rows in a numpy array or pytorch tensor

Time:09-01

I have a (2, 4, 3) NumPy array:

M = np.array([
    [[1, 10, 100],
     [2, 20, 200],
     [3, 30, 300],
     [4, 40, 400]],        
    [[5, 50, 500],
     [6, 60, 600],
     [7, 70, 700],
     [8, 80, 800]]
])

and I want to sum m rows of the first subarray and n rows of the second subarray. For example, summing rows 2, 3 and 4 of the first subarray and rows 1 and 2 of the second (counting from 1) should give:

np.array([
    [[9, 90, 900]],

    [[11, 110, 1100]]
])

How can I do that in a vectorized way? And how can I compute a vectorized min/max over the rows in the same situation, where the number of rows differs per subarray?

CodePudding user response:

You could use np.add.reduceat after applying the appropriate index. In fact, you need to define your index clearly first. I recommend using the normal fancy indexing format that is returned by functions like np.where:

p = [0, 0, 0, 1, 1] # Which plane to grab
r = [1, 2, 3, 0, 1] # Which row to grab in that plane
m = M[p, r, :]

>>> m
array([[  2,  20, 200],
       [  3,  30, 300],
       [  4,  40, 400],
       [  5,  50, 500],
       [  6,  60, 600]])
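
If the selection is given as contiguous row ranges rather than explicit index lists, p and r could probably be built without a Python loop. A minimal sketch, assuming one half-open (start, stop) range per plane; the names starts, stops, counts and group_start are hypothetical and not part of the answer above:

```python
import numpy as np

# Assumed input: one half-open row range [start, stop) per plane.
starts = np.array([1, 0])
stops = np.array([4, 2])

counts = stops - starts  # number of rows taken from each plane

# Plane index, repeated once per selected row.
p = np.repeat(np.arange(len(counts)), counts)

# Position at which each plane's group begins in the flattened selection.
group_start = np.repeat(np.cumsum(counts) - counts, counts)

# Row index: offset inside the group plus that plane's start row.
r = np.arange(counts.sum()) - group_start + np.repeat(starts, counts)
```

With these ranges, p and r come out the same as the hand-written lists above.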

Now you can easily determine the cut-points in r based on changes in p:

splits = np.r_[0, np.flatnonzero(np.diff(p)) + 1]

>>> splits
array([0, 3])

And apply:

>>> np.add.reduceat(m, splits, axis=0)
array([[   9,   90,  900],
       [  11,  110, 1100]])

For a given p and r, you can use a one-liner, which is not completely illegible (IMO):

np.add.reduceat(M[p, r, :], np.r_[0, np.flatnonzero(np.diff(p)) + 1], axis=0)
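
For the min/max part of the question, the same cut-points should work with other ufuncs, since reduceat is a method of binary ufuncs such as np.minimum and np.maximum. A minimal sketch, assuming p is sorted so that the rows of each plane are contiguous (the names row_min and row_max are only for illustration):

```python
import numpy as np

M = np.array([
    [[1, 10, 100], [2, 20, 200], [3, 30, 300], [4, 40, 400]],
    [[5, 50, 500], [6, 60, 600], [7, 70, 700], [8, 80, 800]],
])

p = np.array([0, 0, 0, 1, 1])  # plane of each selected row (assumed sorted)
r = np.array([1, 2, 3, 0, 1])  # row index within that plane

m = M[p, r, :]                                      # selected rows, shape (5, 3)
splits = np.r_[0, np.flatnonzero(np.diff(p)) + 1]   # start of each plane's group

# Element-wise min and max over each group of rows.
row_min = np.minimum.reduceat(m, splits, axis=0)
row_max = np.maximum.reduceat(m, splits, axis=0)
```

The results have shape (2, 3); indexing with [:, None, :] gives the (2, 1, 3) layout shown in the question.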

CodePudding user response:

If you have a list of rows to take with length equal to the first dimension of M, you could get away with something like:

>>> # start and stop indices, right bound not included
>>> rows = [(1,4), (0,2)]
>>> np.vstack([M[i, start:stop].sum(axis=0) for i, (start, stop) in enumerate(rows)])
array([[   9,   90,  900],
       [  11,  110, 1100]])

This assumes rows contains one (start, stop) pair for every subarray along the first axis of M, which is required for the shapes to match.
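
The list comprehension above still loops in Python. One loop-free alternative, which is a different technique from the one in this answer, is to build a boolean mask from the ranges and reduce over the row axis. A minimal sketch, assuming an integer dtype and non-empty ranges (the names idx, mask, sums, mins and maxs are hypothetical):

```python
import numpy as np

M = np.array([
    [[1, 10, 100], [2, 20, 200], [3, 30, 300], [4, 40, 400]],
    [[5, 50, 500], [6, 60, 600], [7, 70, 700], [8, 80, 800]],
])

# One half-open (start, stop) range per plane, as in the answer above.
rows = np.array([(1, 4), (0, 2)])

idx = np.arange(M.shape[1])                          # row indices, shape (4,)
mask = (idx >= rows[:, :1]) & (idx < rows[:, 1:])    # shape (2, 4)
mask = mask[:, :, None]                              # broadcast over columns

# Unselected rows are replaced by a neutral value for each reduction.
info = np.iinfo(M.dtype)                             # assumes an integer dtype
sums = np.where(mask, M, 0).sum(axis=1)
mins = np.where(mask, M, info.max).min(axis=1)
maxs = np.where(mask, M, info.min).max(axis=1)
```

This touches every row of M, so it trades extra work for the absence of a Python loop. For a floating-point array, np.inf and -np.inf would be the natural fill values, and an empty range would return the fill value rather than raise an error.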
