Euclidean distance matrix between two matrices

Time:11-22

I have the following function that calculates the Euclidean distance between every combination of the column vectors in matrix A and matrix B:

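import numpy as np

def distance_matrix(A, B):
    # A is d x n, B is d x m; each column is one vector
    n = A.shape[1]
    m = B.shape[1]

    C = np.zeros((n, m))

    # Compare every column of A with every column of B
    for ai, a in enumerate(A.T):
        for bi, b in enumerate(B.T):
            C[ai, bi] = np.linalg.norm(a - b)
    return C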

This works, and from a d×n matrix and a d×m matrix it produces an n×m matrix holding the Euclidean distance between every pair of column vectors.
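Written out element by element, with A_{ki} denoting row k, column i of A (and likewise for B), each entry of the result is:

```latex
C_{ij} = \lVert A_{:,i} - B_{:,j} \rVert_2 = \sqrt{\sum_{k=1}^{d} \left(A_{ki} - B_{kj}\right)^2}, \qquad i = 1,\dots,n, \quad j = 1,\dots,m
```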

>>> print(A)
[[-1 -1  1  1  2]
 [ 1 -1  2 -1  1]] 
>>> print(B)
[[-2 -1  1  2]
 [-1  2  1 -1]]
>>> print(distance_matrix(A,B))
[[2.23606798 1.         2.         3.60555128]
 [1.         3.         2.82842712 3.        ]
 [4.24264069 2.         1.         3.16227766]
 [3.         3.60555128 2.         1.        ]
 [4.47213595 3.16227766 1.         2.        ]]

I spent some time looking for a NumPy or SciPy function that does this more efficiently. Is there such a function, or what would be the vectorized way to do this?

Answer:

You can use:

np.linalg.norm(A[:,:,None]-B[:,None,:],axis=0)

or, equivalently, without the built-in norm function:

((A[:,:,None]-B[:,None,:])**2).sum(axis=0)**0.5
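As a quick sanity check, the sketch below rebuilds A and B from the question and compares both one-liners against a plain double loop. The helper name `loop_distance_matrix` is illustrative only.

```python
import numpy as np

# Matrices from the question: d=2 rows, each column is one vector
A = np.array([[-1, -1, 1, 1, 2], [1, -1, 2, -1, 1]])
B = np.array([[-2, -1, 1, 2], [-1, 2, 1, -1]])

def loop_distance_matrix(A, B):
    """Reference implementation: one norm per (column of A, column of B) pair."""
    C = np.zeros((A.shape[1], B.shape[1]))
    for i in range(A.shape[1]):
        for j in range(B.shape[1]):
            C[i, j] = np.linalg.norm(A[:, i] - B[:, j])
    return C

# Broadcast to shape (d, n, m), then reduce over the coordinate axis 0
diff = A[:, :, None] - B[:, None, :]
C_norm = np.linalg.norm(diff, axis=0)
C_sum = (diff ** 2).sum(axis=0) ** 0.5

print(C_norm.shape)  # (5, 4)
print(np.allclose(C_norm, C_sum), np.allclose(C_norm, loop_distance_matrix(A, B)))  # True True
```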

We need a 5×4 result, so we add a new axis to each array so that they broadcast against each other:

A[:,:,None]               -> 2,5,1
                               ↑ ↓ 
B[:,None,:]               -> 2,1,4

A[:,:,None] - B[:,None,:] -> 2,5,4

Finally, reducing over axis 0 (the d coordinates), either with the norm or with the sum of squares, gives a (5, 4) ndarray.
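To see the broadcasting step from the shape diagram concretely, here is a minimal sketch that prints the intermediate shapes. The random A and B are placeholders that only share the question's shapes (d=2, n=5, m=4).

```python
import numpy as np

d, n, m = 2, 5, 4
rng = np.random.default_rng(0)
A = rng.normal(size=(d, n))  # n column vectors of length d
B = rng.normal(size=(d, m))  # m column vectors of length d

a3 = A[:, :, None]  # new trailing axis: (d, n, 1)
b3 = B[:, None, :]  # new middle axis:   (d, 1, m)
diff = a3 - b3      # size-1 axes are stretched: (d, n, m)

print(a3.shape, b3.shape, diff.shape)  # (2, 5, 1) (2, 1, 4) (2, 5, 4)

# diff[:, i, j] is exactly the difference vector A[:, i] - B[:, j]
C = np.linalg.norm(diff, axis=0)  # reduce over d -> (n, m)
print(C.shape)  # (5, 4)
```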

Answer:

Yes, you can broadcast your vectors:

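import numpy as np

A = np.array([[-1, -1,  1,  1,  2], [ 1, -1,  2, -1,  1]])
B = np.array([[-2, -1,  1,  2], [-1,  2,  1, -1]])

# Transpose so vectors are rows: (n, 1, d) - (1, m, d) -> (n, m, d), then take the norm over d
C = np.linalg.norm(A.T[:, None, :] - B.T[None, :, :], axis=-1)
print(repr(C))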

array([[2.23606798, 1.        , 2.        , 3.60555128],
       [1.        , 3.        , 2.82842712, 3.        ],
       [4.24264069, 2.        , 1.        , 3.16227766],
       [3.        , 3.60555128, 2.        , 1.        ],
       [4.47213595, 3.16227766, 1.        , 2.        ]])
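Both broadcasting answers build an intermediate difference array holding d·n·m values, which can be large. A common alternative, sketched here and not taken from either answer, expands ‖a − b‖² = ‖a‖² + ‖b‖² − 2a·b so that the only large temporary is the n×m matrix of dot products. It keeps the question's column-vector layout. The function name `gram_distance_matrix` is hypothetical.

```python
import numpy as np

def gram_distance_matrix(A, B):
    """Pairwise Euclidean distances between the columns of A (d x n) and B (d x m)."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    sq_a = (A ** 2).sum(axis=0)[:, None]  # squared norms of A's columns, shape (n, 1)
    sq_b = (B ** 2).sum(axis=0)[None, :]  # squared norms of B's columns, shape (1, m)
    sq = sq_a + sq_b - 2.0 * (A.T @ B)    # squared distances, shape (n, m)
    # Rounding can make tiny values slightly negative; clip before the square root
    return np.sqrt(np.maximum(sq, 0.0))

A = np.array([[-1, -1, 1, 1, 2], [1, -1, 2, -1, 1]])
B = np.array([[-2, -1, 1, 2], [-1, 2, 1, -1]])
C = gram_distance_matrix(A, B)
print(np.allclose(C, np.linalg.norm(A.T[:, None, :] - B.T[None, :, :], axis=-1)))  # True
```

The trade-off is precision: when two points are very close relative to their norms, the subtraction cancels most significant digits, so the direct broadcasting version is more accurate for small distances.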

You can get an explanation of how it works here:

https://sparrow.dev/pairwise-distance-in-numpy/
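SciPy also offers a ready-made function for this, `scipy.spatial.distance.cdist`. It expects observations as rows, so the question's column-vector matrices are transposed first. Here is a minimal sketch, assuming SciPy is installed:

```python
import numpy as np
from scipy.spatial.distance import cdist

A = np.array([[-1, -1, 1, 1, 2], [1, -1, 2, -1, 1]])
B = np.array([[-2, -1, 1, 2], [-1, 2, 1, -1]])

# cdist takes (n, d) and (m, d) arrays of row vectors and returns an (n, m) matrix
C = cdist(A.T, B.T, metric="euclidean")
print(C.shape)  # (5, 4)
```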
