TLDR: given two tensors `t1` and `t2` that represent `b` samples of a tensor with shape `(c, h, w)` (i.e., each tensor has shape `(b, c, h, w)`), I'm trying to calculate the pairwise distance between `t1[i]` and `t2[j]` for all `i`, `j` efficiently.
Some more context: I've extracted ResNet18 activations for both my train and test data (CIFAR10) and I'm trying to implement k-nearest-neighbours. A possible pseudo-code might be:
```
for te in test_activations:
    distances = []
    for tr in train_activations:
        distances.append(||te - tr||)
    neighbors = k_smallest_elements(distances)
    prediction(te) = majority_vote(labels(neighbors))
```
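The pseudo-code above can be turned into a runnable (if slow) baseline. This is a sketch with illustrative names: `knn_predict_loop`, `test_acts`, `train_acts`, and `train_labels` are hypothetical, assuming activations of shape `(N, c, h, w)` and one integer label per training sample:

```python
import torch

def knn_predict_loop(test_acts, train_acts, train_labels, k=5):
    preds = []
    for te in test_acts:
        # Euclidean distance between this test sample and every train sample
        dists = torch.stack([(te - tr).norm() for tr in train_acts])
        # Indices of the k smallest distances
        _, idx = dists.topk(k, largest=False)
        # Majority vote over the neighbours' labels
        preds.append(train_labels[idx].mode().values)
    return torch.stack(preds)
```

This is the un-vectorised reference the question wants to speed up; it is useful mainly for checking a vectorised version against.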
I'm trying to vectorise this process given batches from the test and train activation datasets. I've tried iterating over the batches (rather than the samples) and using `torch.cdist(train_batch, test_batch)`, but I'm not quite sure how this function handles multi-dimensional tensors. The documentation for `torch.cdist(x1, x2, ...)` states:

> If `x1` has shape `B×P×M` and `x2` has shape `B×R×M` then the output will have shape `B×P×R`.

This doesn't seem to handle my case (see below).
A minimal example:

```python
b, c, h, w = 1000, 128, 28, 28  # actual dimensions in my problem
train_batch = torch.randn(b, c, h, w)
test_batch = torch.randn(b, c, h, w)
d = torch.cdist(train_batch, test_batch)
```
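For comparison, the documented contract does hold when the inputs are 3-D; a quick sanity check (the shapes here are arbitrary, chosen only to make the three dimensions distinguishable):

```python
import torch

# Documented cdist contract: x1 (B, P, M), x2 (B, R, M) -> output (B, P, R)
B, P, R, M = 4, 5, 6, 3
x1 = torch.randn(B, P, M)
x2 = torch.randn(B, R, M)
d = torch.cdist(x1, x2)
print(d.shape)  # torch.Size([4, 5, 6])
```

The mismatch in the question is that a `(b, c, h, w)` tensor doesn't directly fit this `(B, P, M)` layout.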
You can think of `test_batch` and `train_batch` as the tensors in the nested loop `for train_batch in train: for test_batch in test: ...`, and the expected output for each such pair of batches would have shape `(b, b)` (one distance per train/test sample pair).
CodePudding user response:
It is common to have to reshape your data before feeding it to a built-in PyTorch operator. As you've said, `torch.cdist` works with two inputs shaped `(B, P, M)` and `(B, R, M)` and returns a tensor shaped `(B, P, R)`.

Instead, you have two tensors shaped the same way: `(b, c, h, w)`. If we match those dimensions we have `B=b`, `M=c`, while `P=h*w` (from the 1st tensor) and `R=h*w` (from the 2nd tensor). This requires flattening the spatial dimensions together and swapping the last two axes. Something like:
```python
>>> x1 = train_batch.flatten(2).transpose(1, 2)
>>> x2 = test_batch.flatten(2).transpose(1, 2)
>>> d = torch.cdist(x1, x2)
```
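With small dimensions this reshaping can be checked end to end; the sizes below are illustrative, not the ones from the question:

```python
import torch

b, c, h, w = 2, 3, 4, 4  # small illustrative dimensions
train_batch = torch.randn(b, c, h, w)
test_batch = torch.randn(b, c, h, w)

# (b, c, h, w) -> (b, c, h*w) -> (b, h*w, c): flatten spatial dims, channels last
x1 = train_batch.flatten(2).transpose(1, 2)
x2 = test_batch.flatten(2).transpose(1, 2)

d = torch.cdist(x1, x2)
print(d.shape)  # torch.Size([2, 16, 16])

# Each entry is the distance between a spatial location of train_batch and
# a spatial location of test_batch at the same batch index:
iy, ix, jy, jx = 1, 2, 3, 0
manual = (train_batch[0, :, iy, ix] - test_batch[0, :, jy, jx]).norm()
assert torch.isclose(d[0, iy * w + ix, jy * w + jx], manual, atol=1e-5)
```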
Now `d` contains the distance between all possible pairs `(train_batch[b, :, iy, ix], test_batch[b, :, jy, jx])` and is shaped `(b, h*w, h*w)`.
You can then apply a knn using `torch.topk` with `largest=False` (not `argmax`, which would return the single *farthest* point) to retrieve the k closest neighbours from one element of the training batch to the test batch.
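Note that if the goal is sample-level k-NN (one distance per train/test sample pair, rather than per spatial location), another option is to flatten each sample to a single vector before calling `cdist`. A sketch, where `knn_predict` and its arguments are illustrative names:

```python
import torch

def knn_predict(test_batch, train_batch, train_labels, k=5):
    # Flatten each (c, h, w) sample to one feature vector, then compute all
    # pairwise distances at once: output shape (num_test, num_train)
    d = torch.cdist(test_batch.flatten(1), train_batch.flatten(1))
    # k smallest distances per test sample (largest=False -> closest)
    _, idx = d.topk(k, dim=1, largest=False)
    # Majority vote over each row of neighbour labels
    return train_labels[idx].mode(dim=1).values
```

This follows the question's original TLDR (distances between `t1[i]` and `t2[j]` for all `i`, `j`) rather than the per-spatial-location pairing above.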