Why do you multiply the two images to see the correlation?

Time:09-24

I have some questions about the CP-VTON correlation module:

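```python
import torch

def feature_correlation(feature_A, feature_B):  # illustrative wrapper; this code sits inside the module's forward()
    b, c, h, w = feature_A.size()
    # flatten the spatial dims of both feature maps to a single axis of length h*w
    feature_A = feature_A.transpose(2, 3).contiguous().view(b, c, h * w)  # (b, c, hw)
    feature_B = feature_B.view(b, c, h * w).transpose(1, 2)                 # (b, hw, c)

    # perform matrix mult.
    feature_mul = torch.bmm(feature_B, feature_A)
    print(feature_mul.size())  # torch.Size([4, 192, 192]) for my input
    return feature_mul
```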
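The shapes in the snippet above can be traced with random inputs. This is a minimal sketch under assumed sizes: b = 4 and h, w = 16, 12 are chosen only so that h*w = 192 matches the printed output, and c = 8 is arbitrary; none of these come from CP-VTON's actual configuration.

```python
import torch

# Assumed example sizes: b=4 and h*w = 16*12 = 192 match the printed output;
# c=8 is arbitrary and chosen only for illustration.
b, c, h, w = 4, 8, 16, 12
feature_A = torch.randn(b, c, h, w)
feature_B = torch.randn(b, c, h, w)

# Flatten the two spatial dims into one axis of length h*w.
A = feature_A.transpose(2, 3).contiguous().view(b, c, h * w)  # (b, c, hw)
B = feature_B.view(b, c, h * w).transpose(1, 2)                 # (b, hw, c)

# For each batch element: (hw, c) @ (c, hw) -> (hw, hw)
feature_mul = torch.bmm(B, A)

print(A.shape)            # torch.Size([4, 8, 192])
print(B.shape)            # torch.Size([4, 192, 8])
print(feature_mul.shape)  # torch.Size([4, 192, 192])
```

The channel dimension c is the one that gets summed away by the matrix product, which is why only the two hw axes survive in the result.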

About this code:

  1. I don't understand why the matrix multiplication is set up so that the result has shape (b, hw, hw).

  2. It is said that multiplying the reshaped feature maps this way extracts the correlation, but I don't understand why. I mean the bmm part.

Answer:

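The spatial correlation consists of computing the dot product of the feature vectors of every (feature_A[k,:,i], feature_B[k,:,j]) pair, where both maps are viewed as (b, c, h*w), k is the batch index, and i, j are flattened spatial positions. To do that, you first need to flatten the two spatial dimensions, which results in a dimension of size h*w in both tensors. Your two operands then have shapes (b, c, hw) for feature_A and (b, hw, c) for feature_B. For each batch element, torch.bmm(feature_B, feature_A) multiplies an (hw, c) matrix by a (c, hw) matrix, so you end up with a tensor shaped (b, hw, hw); in your case b = 4 and hw = 192.

Entry [k, j, i] of that result is the sum over the c channels of feature_B[k,c,j] * feature_A[k,c,i], which is exactly the dot product of the feature vector at position j in feature_B with the one at position i in feature_A. The dot product is large when the two vectors point in similar directions (it equals their cosine similarity if the features are L2-normalized), so the (hw, hw) matrix scores how well every location in one feature map matches every location in the other. That is why this product is called a correlation. Note that feature_A is transposed (h and w swapped) before flattening, so its positions are enumerated in a different order than feature_B's; this only permutes the columns of the result, not the values it contains.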
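The equivalence between the bmm result and the per-pair dot products can be checked numerically. The sketch below uses small random tensors (toy sizes chosen for illustration, not CP-VTON's), compares torch.bmm with an explicit triple loop and with torch.einsum, and then, assuming the features are L2-normalized over channels, shows that correlating a map with itself peaks at the matching location, including the column reordering caused by the transpose.

```python
import torch

torch.manual_seed(0)

# Assumed toy sizes for illustration only.
b, c, h, w = 2, 8, 4, 3
feature_A = torch.randn(b, c, h, w)
feature_B = torch.randn(b, c, h, w)

# The same reshapes as in the question.
A = feature_A.transpose(2, 3).contiguous().view(b, c, h * w)  # (b, c, hw)
B = feature_B.view(b, c, h * w).transpose(1, 2)                 # (b, hw, c)
corr = torch.bmm(B, A)                                          # (b, hw, hw)

# Explicit version: one dot product over the channels per (k, j, i) triple.
corr_loop = torch.empty(b, h * w, h * w)
for k in range(b):
    for j in range(h * w):
        for i in range(h * w):
            corr_loop[k, j, i] = torch.dot(B[k, j, :], A[k, :, i])

# The same computation as an einsum that sums over the channel axis c.
corr_einsum = torch.einsum("kjc,kci->kji", B, A)

print(torch.allclose(corr, corr_loop, atol=1e-5))    # True
print(torch.allclose(corr, corr_einsum, atol=1e-5))  # True

# Assumption: features L2-normalized over channels, so each entry is a cosine
# similarity in [-1, 1]. Correlating a map with itself then peaks at the
# matching location.
f = torch.nn.functional.normalize(feature_A, dim=1)
A_n = f.transpose(2, 3).contiguous().view(b, c, h * w)
B_n = f.view(b, c, h * w).transpose(1, 2)
self_corr = torch.bmm(B_n, A_n)  # (b, hw, hw)

# Location (y, x) is row j = y*w + x (B is flattened row-major), but because A
# was transposed before flattening, the same location is column x*h + y.
y, x = 2, 1
j = y * w + x
print(self_corr[0, j].argmax().item() == x * h + y)  # True
print(round(self_corr[0, j].max().item(), 4))         # 1.0
```

The loop is only there to make the per-pair dot products explicit; bmm computes all hw × hw of them in a single batched matrix multiplication, which is the reason for reshaping the operands to (b, hw, c) and (b, c, hw).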
