What is the fastest way to calculate cosine similarity between rows of two same-shape matrices

Time:04-28

For example, I have two 2D arrays as follows:

import numpy as np

X = np.array([[4, 4, 4, 2],
              [3, 1, 2, 2],
              [1, 3, 3, 3],
              [1, 3, 1, 2]])
Y = np.array([[2, 1, 1, 4],
              [2, 1, 1, 1],
              [4, 1, 4, 4],
              [4, 2, 3, 4]])

I want to calculate the cosine similarity between corresponding rows of X and Y, for example:

def cos(feats1, feats2):
    """
    Compute the cosine similarity between two 1D vectors.
    """
    sim = np.dot(feats1, feats2) / (np.linalg.norm(feats1) * np.linalg.norm(feats2))
    return sim

# Compare row i of X with row i of Y, one row at a time
for i in range(X.shape[0]):
    print(cos(X[i, :], Y[i, :]))
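For reference, this is the quantity computed for each row index i. The notation x_i and y_i for the i-th rows of X and Y is introduced here, not taken from the question:

$$
\cos(\mathbf{x}_i, \mathbf{y}_i)
= \frac{\mathbf{x}_i \cdot \mathbf{y}_i}{\lVert \mathbf{x}_i \rVert \, \lVert \mathbf{y}_i \rVert}
= \frac{\sum_{j} X_{ij} Y_{ij}}{\sqrt{\sum_{j} X_{ij}^{2}} \, \sqrt{\sum_{j} Y_{ij}^{2}}}
$$

Every sum runs over the column index j. The whole computation can therefore be expressed as reductions along axis=1, with no loop over rows.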

Right now, I am using a for loop to calculate the cosine similarity between the row vectors. But X and Y have a shape like (1200000000, 512), so the for loop takes a really long time.

My question is how I can take advantage of linear algebra and NumPy to speed up this process.

Or is there any other method that can perform this calculation more efficiently?

Thanks

Answer:

This is possible in a single line: the trick is to specify the axis along which to compute the norm and the dot product.

import numpy as np

X = np.random.randn(3, 2)
Y = np.random.randn(3, 2)
(X * Y).sum(axis=1) / np.linalg.norm(X, axis=1) / np.linalg.norm(Y, axis=1)

The first part, (X * Y).sum(axis=1), computes the dot products. X * Y multiplies the arrays element-wise, and axis=1 sums over the columns, which gives one result per row (per data point).

The second part computes the norm of each row in the same way, by passing axis=1 to np.linalg.norm.
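Here is a possible variant that is not part of the original answer. np.einsum with the subscript string 'ij,ij->i' multiplies matching elements and sums over j for each row i. It computes the same row-wise dot products without first building the full X * Y temporary array, which may save memory on large inputs. Whether it is also faster depends on your data and NumPy build, so benchmark it. The function name rowwise_cosine_einsum is made up for this sketch.

```python
import numpy as np

def rowwise_cosine_einsum(X, Y):
    """Cosine similarity between corresponding rows of X and Y (hypothetical helper).

    'ij,ij->i' means: multiply X[i, j] * Y[i, j] and sum over j,
    producing one value per row i.
    """
    dots = np.einsum('ij,ij->i', X, Y)
    # Row norms with the same pattern: sqrt(sum_j X[i, j] ** 2)
    x_norms = np.sqrt(np.einsum('ij,ij->i', X, X))
    y_norms = np.sqrt(np.einsum('ij,ij->i', Y, Y))
    return dots / (x_norms * y_norms)

X = np.array([[4, 4, 4, 2], [3, 1, 2, 2], [1, 3, 3, 3], [1, 3, 1, 2]])
Y = np.array([[2, 1, 1, 4], [2, 1, 1, 1], [4, 1, 4, 4], [4, 2, 3, 4]])
print(rowwise_cosine_einsum(X, Y))
```

On the example matrices from the question, this should agree with the (X * Y).sum(axis=1) version above up to floating-point rounding.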

Answer:

If you only want to use NumPy, make good use of vectorized (element-wise) operations:

>>> import numpy as np
>>> def cos(x, y):
...     return (x * y).sum(axis=1) / (np.linalg.norm(x, axis=1) * np.linalg.norm(y, axis=1))
...
>>> X = np.array([[4, 4, 4, 2],
...    [3, 1, 2, 2],
...    [1, 3, 3, 3],
...    [1, 3, 1, 2]])
>>> Y = np.array([[2, 1, 1, 4],
...    [2, 1, 1, 1],
...    [4, 1, 4, 4],
...    [4, 2, 3, 4]])
>>> cos(X, Y)
array([0.70957488, 0.97995789, 0.83692133, 0.80829038])
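Both answers assume X and Y fit in memory. At the (1200000000, 512) size from the question, each matrix holds 614,400,000,000 elements, which is about 2.46 TB even as float32. It will therefore most likely have to be processed in pieces. The sketch below is an assumption and is not taken from either answer. The hypothetical rowwise_cosine_chunked walks over blocks of chunk_rows rows and applies the same vectorized formula to each block. It writes the results into a preallocated output array, which you could pass in as an np.memmap if even the result vector is too large for RAM.

```python
import numpy as np

def rowwise_cosine_chunked(X, Y, chunk_rows=1_000_000, out=None):
    """Row-wise cosine similarity, chunk_rows rows at a time (hypothetical helper).

    X and Y may be any arrays that support row slicing, e.g. np.memmap,
    so only one chunk of each needs to be in memory at once.
    chunk_rows is a tuning knob: choose it to fit your available RAM.
    """
    if X.shape != Y.shape:
        raise ValueError("X and Y must have the same shape")
    n = X.shape[0]
    if out is None:
        out = np.empty(n, dtype=np.float64)
    for start in range(0, n, chunk_rows):
        stop = min(start + chunk_rows, n)
        # Copy one chunk into memory as float64 for accuracy
        x = np.asarray(X[start:stop], dtype=np.float64)
        y = np.asarray(Y[start:stop], dtype=np.float64)
        # Row-wise dot products and norms, exactly as in the answers above
        dots = np.einsum('ij,ij->i', x, y)
        norms = np.linalg.norm(x, axis=1) * np.linalg.norm(y, axis=1)
        out[start:stop] = dots / norms
    return out

# Small usage example with random float32 data standing in for the real matrices
rng = np.random.default_rng(0)
A = rng.standard_normal((10_000, 512), dtype=np.float32)
B = rng.standard_normal((10_000, 512), dtype=np.float32)
sims = rowwise_cosine_chunked(A, B, chunk_rows=4096)
print(sims[:3])
```

Rows with a zero norm produce nan, and NumPy emits a RuntimeWarning for the 0/0 division. A larger chunk_rows means fewer Python-level iterations but more memory per chunk.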