Vectorised pairwise distance

Time:06-09

TL;DR: given two tensors t1 and t2, each holding b samples of shape (c, h, w) (i.e., each tensor has shape (b, c, h, w)), I'm trying to efficiently compute the pairwise distance between t1[i] and t2[j] for all i, j.


Some more context: I've extracted ResNet18 activations for both my train and test data (CIFAR10) and I'm trying to implement k-nearest neighbours. Possible pseudocode:

for te in test_activations:
    distances = []
    for tr in train_activations:
        distances.append(||te-tr||)
    neighbors = k_smallest_elements(distances)
    prediction(te) = majority_vote(labels(neighbors))
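A minimal runnable version of this loop in PyTorch might look like the sketch below. It uses small random tensors as hypothetical stand-ins for the activations; train_activations, train_labels, test_activations and k are placeholder names, and torch.topk and torch.mode play the roles of k_smallest_elements and majority_vote.

```python
import torch

# Hypothetical toy data standing in for the real ResNet18 activations:
# 50 train and 10 test samples, each of shape (c, h, w) = (8, 4, 4).
torch.manual_seed(0)
train_activations = torch.randn(50, 8, 4, 4)
train_labels = torch.randint(0, 10, (50,))  # hypothetical labels in [0, 10)
test_activations = torch.randn(10, 8, 4, 4)
k = 5

predictions = []
for te in test_activations:
    # Euclidean distance over all c*h*w entries (linalg.norm flattens when dim=None)
    distances = torch.stack([torch.linalg.norm(te - tr) for tr in train_activations])
    # indices of the k smallest distances
    neighbors = torch.topk(distances, k, largest=False).indices
    # majority vote over the neighbours' labels
    predictions.append(torch.mode(train_labels[neighbors]).values)
predictions = torch.stack(predictions)  # shape (10,)
```

This is O(n_test * n_train) Python-level iterations, which is exactly what vectorising is meant to avoid.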

I'm trying to vectorise this process over batches from the test and train activation datasets. I've tried iterating over batches (rather than samples) and calling torch.cdist(train_batch, test_batch), but I'm not sure how this function handles multi-dimensional tensors, since the documentation states:

torch.cdist(x1, x2,...):
If x1 has shape BxPxM and x2 has shape BxRxM then the output will have shape BxPxR

This doesn't seem to match my case (see below).


A minimal example can be found here:

import torch

b, c, h, w = 1000, 128, 28, 28  # actual dimensions in my problem
train_batch = torch.randn(b, c, h, w)
test_batch = torch.randn(b, c, h, w)
d = torch.cdist(train_batch, test_batch)

You can think of test_batch and train_batch as the tensors in the loop for test_batch in test: for train_batch in train: ..., and the expected output would have shape (b,).
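To see what the documented shapes mean in practice, here is a small sketch with made-up sizes. It assumes a recent PyTorch version, where torch.cdist treats every dimension except the last two as a batch dimension; under that assumption, the minimal example above returns a tensor of shape (b, c, h, h) rather than per-sample distances.

```python
import torch

# Made-up sizes for illustration only
B, P, R, M = 2, 3, 5, 4
x1 = torch.randn(B, P, M)
x2 = torch.randn(B, R, M)
d = torch.cdist(x1, x2)  # shape (B, P, R)
print(d.shape)  # torch.Size([2, 3, 5])

# d[n, i, j] is the Euclidean distance between x1[n, i] and x2[n, j]
assert torch.allclose(d[1, 2, 4], torch.linalg.norm(x1[1, 2] - x2[1, 4]), atol=1e-5)

# With 4-D inputs, the leading two dims act as batch dims and the last two
# as (points, features): here h plays P/R and w plays M.
b, c, h, w = 2, 3, 6, 7
t1 = torch.randn(b, c, h, w)
t2 = torch.randn(b, c, h, w)
print(torch.cdist(t1, t2).shape)  # torch.Size([2, 3, 6, 6])
```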

Answer:

It is common to have to reshape your data before feeding it to a built-in PyTorch operator. As you've said, torch.cdist works with two inputs shaped (B, P, M) and (B, R, M) and returns a tensor shaped (B, P, R).

Your two tensors, however, are both shaped (b, c, h, w). Matching those dimensions gives B=b and M=c, with P=h*w (from the first tensor) and R=h*w (from the second). This requires flattening the two spatial dimensions into one and swapping the last two axes, something like:

>>> x1 = train_batch.flatten(2).transpose(1,2)
>>> x2 = test_batch.flatten(2).transpose(1,2)

>>> d = torch.cdist(x1, x2)

Now d is shaped (b, h*w, h*w) and contains the distance for every pair (train_batch[b, :, iy, ix], test_batch[b, :, jy, jx]) that shares the same batch index b.
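A quick way to check this indexing, sketched with small hypothetical sizes: with row-major flattening, spatial position (iy, ix) maps to index iy * w + ix, so a single entry of d can be compared against a direct norm.

```python
import torch

# Small hypothetical sizes so the check runs instantly
b, c, h, w = 2, 3, 4, 5
train_batch = torch.randn(b, c, h, w)
test_batch = torch.randn(b, c, h, w)

# (b, c, h, w) -> (b, c, h*w) -> (b, h*w, c): one c-dim vector per spatial position
x1 = train_batch.flatten(2).transpose(1, 2)
x2 = test_batch.flatten(2).transpose(1, 2)
d = torch.cdist(x1, x2)  # shape (b, h*w, h*w)

# Flattened index i corresponds to (iy, ix) with i = iy * w + ix
n, iy, ix, jy, jx = 1, 2, 3, 0, 4
i, j = iy * w + ix, jy * w + jx
expected = torch.linalg.norm(train_batch[n, :, iy, ix] - test_batch[n, :, jy, jx])
assert torch.allclose(d[n, i, j], expected, atol=1e-5)
print(d.shape)  # torch.Size([2, 20, 20])
```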

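You can then implement kNN by taking the k smallest distances (e.g. with topk(k, largest=False), or argmin for k=1; argmax would return the farthest element) to retrieve the k closest neighbours of an element of the training batch in the test batch.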
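A minimal sketch of that kNN step, assuming you already have a 2-D distance matrix with one row per test sample and one column per training sample. The names dist, train_labels and k are hypothetical and filled with random data; torch.mode serves as the majority vote.

```python
import torch

# Hypothetical inputs: dist[t, n] = distance between test sample t and training sample n
torch.manual_seed(0)
n_test, n_train, k = 4, 30, 5
dist = torch.rand(n_test, n_train)
train_labels = torch.randint(0, 10, (n_train,))

# k smallest distances per test sample (largest=False picks the closest ones)
knn_idx = dist.topk(k, dim=1, largest=False).indices  # (n_test, k)
knn_labels = train_labels[knn_idx]                    # (n_test, k)
# majority vote along each row
predictions = knn_labels.mode(dim=1).values           # (n_test,)
```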
