Row wise outer addition

Time:10-02

I have two NumPy arrays: A of shape (m, k) and B of shape (m, n). I would like to compute an array C of shape (m, k, n) where each row r of C (that is, C[r]) is the outer addition of row r of A and row r of B. I can do that using a for loop as follows:

import numpy as np

A = np.array([[ 0,  1,  2,  3,  4],
              [ 5,  6,  7,  8,  9],
              [10, 11, 12, 13, 14]])

B = np.array([[1., 1., 1., 1., 1., 1.],
              [2., 2., 2., 2., 2., 2.],
              [3., 3., 3., 3., 3., 3.]])

A.shape  # (3, 5)
B.shape  # (3, 6)

# C[i] is the outer sum of row i of A and row i of B
C = np.zeros((A.shape[0], A.shape[1], B.shape[1]))
for i in range(A.shape[0]):
    C[i] = A[i][:, None] + B[i]

>>> C
array([[[ 1.,  1.,  1.,  1.,  1.,  1.],
        [ 2.,  2.,  2.,  2.,  2.,  2.],
        [ 3.,  3.,  3.,  3.,  3.,  3.],
        [ 4.,  4.,  4.,  4.,  4.,  4.],
        [ 5.,  5.,  5.,  5.,  5.,  5.]],

       [[ 7.,  7.,  7.,  7.,  7.,  7.],
        [ 8.,  8.,  8.,  8.,  8.,  8.],
        [ 9.,  9.,  9.,  9.,  9.,  9.],
        [10., 10., 10., 10., 10., 10.],
        [11., 11., 11., 11., 11., 11.]],

       [[13., 13., 13., 13., 13., 13.],
        [14., 14., 14., 14., 14., 14.],
        [15., 15., 15., 15., 15., 15.],
        [16., 16., 16., 16., 16., 16.],
        [17., 17., 17., 17., 17., 17.]]])
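Written element-wise, the loop fills C as follows (the index names i and j are introduced here only for illustration):

```latex
C_{r,i,j} = A_{r,i} + B_{r,j}, \qquad 0 \le r < m,\; 0 \le i < k,\; 0 \le j < n
```

For example, C_{1,0,0} = A_{1,0} + B_{1,0} = 5 + 2 = 7, which matches the first entry of the second block of the output above.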

But is there a way to vectorize the above code to get rid of the for loop?

Answer:

You could use broadcasting. A has shape (m, k) and B has shape (m, n), so insert a new axis into each of them in complementary positions: one becomes (m, k, 1) and the other (m, 1, n). Adding them then broadcasts both to (m, k, n) and performs the outer addition for every row at once:

>>> A[..., None] + B[:, None]
array([[[ 1.,  1.,  1.,  1.,  1.,  1.],
        [ 2.,  2.,  2.,  2.,  2.,  2.],
        [ 3.,  3.,  3.,  3.,  3.,  3.],
        [ 4.,  4.,  4.,  4.,  4.,  4.],
        [ 5.,  5.,  5.,  5.,  5.,  5.]],

       [[ 7.,  7.,  7.,  7.,  7.,  7.],
        [ 8.,  8.,  8.,  8.,  8.,  8.],
        [ 9.,  9.,  9.,  9.,  9.,  9.],
        [10., 10., 10., 10., 10., 10.],
        [11., 11., 11., 11., 11., 11.]],

       [[13., 13., 13., 13., 13., 13.],
        [14., 14., 14., 14., 14., 14.],
        [15., 15., 15., 15., 15., 15.],
        [16., 16., 16., 16., 16., 16.],
        [17., 17., 17., 17., 17., 17.]]])
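A minimal sketch that wraps the broadcasting trick in a function and checks it against the loop from the question. The function names row_outer_add and row_outer_add_loop are hypothetical, chosen for illustration; A[:, :, np.newaxis] and B[:, np.newaxis, :] are just the explicit spellings of A[..., None] and B[:, None] for 2-D inputs.

```python
import numpy as np


def row_outer_add(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return C with C[r, i, j] = A[r, i] + B[r, j].

    A has shape (m, k), B has shape (m, n); the result has shape (m, k, n).
    """
    if A.ndim != 2 or B.ndim != 2 or A.shape[0] != B.shape[0]:
        raise ValueError("expected A of shape (m, k) and B of shape (m, n)")
    # (m, k, 1) + (m, 1, n) broadcasts to (m, k, n)
    return A[:, :, np.newaxis] + B[:, np.newaxis, :]


def row_outer_add_loop(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Hypothetical reference version: the for loop from the question."""
    C = np.zeros((A.shape[0], A.shape[1], B.shape[1]))
    for i in range(A.shape[0]):
        C[i] = A[i][:, None] + B[i]
    return C


# Same inputs as in the question
A = np.arange(15).reshape(3, 5)
B = np.repeat(np.array([[1.0], [2.0], [3.0]]), 6, axis=1)

print(A[:, :, np.newaxis].shape)  # (3, 5, 1)
print(B[:, np.newaxis, :].shape)  # (3, 1, 6)

C = row_outer_add(A, B)
print(C.shape)  # (3, 5, 6)
print(np.array_equal(C, row_outer_add_loop(A, B)))  # True
```

One difference worth noting: the loop preallocates a float64 array with np.zeros, while the broadcast version lets NumPy choose the result dtype from A and B, so two integer inputs would give an integer C.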