I have a question about this part of the CP-VTON module:
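```python
import torch

b, c, h, w = 4, 8, 16, 12  # hypothetical sizes for illustration (h*w = 192)
feature_A, feature_B = torch.randn(b, c, h, w), torch.randn(b, c, h, w)
feature_A = feature_A.transpose(2,3).contiguous().view(b,c,h*w)  # (b, c, hw)
feature_B = feature_B.view(b,c,h*w).transpose(1,2)               # (b, hw, c)
# perform matrix mult.
feature_mul = torch.bmm(feature_B,feature_A)                     # (b, hw, hw)
print(feature_mul.size()) #torch.Size([4, 192, 192])
```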
In this code, I don't understand why the matrix multiplication is set up so that the result has shape (b, hw, hw). It is said that multiplying the reshaped feature maps this way extracts the correlation, but I don't see why. I'm asking about the `bmm` part.
Answer:
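The spatial correlation consists of computing the dot product of the feature vectors of every (`feature_A[k,:,i]`, `feature_B[k,:,j]`) pair, i.e. every spatial position i of A against every spatial position j of B, for each batch element k. To do that, you first need to flatten the spatial dimensions, which gives a dimension of size h*w on both tensors. Your two operands then have shapes (b, c, hw) for `feature_A` and (b, hw, c) for `feature_B`. `bmm` multiplies them batch element by batch element and sums over the shared channel dimension c, so you end up with a tensor shaped (b, hw, hw): one dot product for every pair of spatial positions. In your printed output, b = 4 and hw = 192.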
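A minimal sketch to check this numerically, assuming hypothetical sizes (b=4, h=16, w=12, so h*w = 192 as in the printed size, and an arbitrary c=8) and random tensors standing in for the real feature maps. It repeats the two reshapes from the question, then compares `bmm` against a `torch.einsum` that sums over the channel index and against a single explicit `torch.dot` for one pair of positions. It also shows which spatial cell a flattened index refers to: because `feature_A` is transposed before flattening, its positions run column by column, while `feature_B`'s run row by row.

```python
import torch

# Hypothetical sizes: b=4 and h*w = 16*12 = 192 reproduce torch.Size([4, 192, 192]);
# c=8 is an arbitrary small channel count. Random inputs stand in for real features.
b, c, h, w = 4, 8, 16, 12
torch.manual_seed(0)
fA = torch.randn(b, c, h, w)
fB = torch.randn(b, c, h, w)

# Same reshapes as in the question.
A = fA.transpose(2, 3).contiguous().view(b, c, h * w)  # (b, c, hw), column-major positions
B = fB.view(b, c, h * w).transpose(1, 2)                # (b, hw, c), row-major positions
corr = torch.bmm(B, A)                                  # (b, hw, hw)

# bmm contracts over c, so corr[k, j, i] = <B[k, j, :], A[k, :, i]>.
corr_einsum = torch.einsum('kjc,kci->kji', B, A)

# One entry computed as an explicit dot product of two c-dimensional vectors.
k, i, j = 1, 5, 100
dot = torch.dot(A[k, :, i], B[k, j, :])

# Map the flat indices back to (row, col) cells of the original h x w grids.
a_vec = fA[k, :, i % h, i // h]   # A was flattened after transpose(2, 3)
b_vec = fB[k, :, j // w, j % w]   # B was flattened directly

print(corr.shape)  # torch.Size([4, 192, 192])
print(torch.allclose(corr, corr_einsum, atol=1e-5))
print(torch.allclose(corr[k, j, i], dot, atol=1e-5))
print(torch.equal(A[k, :, i], a_vec), torch.equal(B[k, j, :], b_vec))
```

Since every entry is just a channel-wise dot product, the (hw, hw) matrix holds a similarity score for every pair of spatial positions, which is why the result is described as a correlation.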