I want to get rid of the "for loop iteration" by using PyTorch functions in my code. But the formula is complicated and I can't find a clue. Can the "for loop iteration" below be replaced with a Torch operation?
import torch

B = 10
L = 20
H = 5
mat_A = torch.randn(B, L, L, H)  # shape (B, L, L, H)
mat_B = torch.randn(L, B, B, H)  # shape (L, B, B, H)
tmp_B = torch.zeros_like(mat_B)  # shape (L, B, B, H)
for x in range(L):
    for y in range(B):
        for z in range(B):
            tmp_B[:, y, z, :] = mat_B[x, y, z, :] * mat_A[z, x, :, :]
CodePudding user response:
This looks like a good setup for applying torch.einsum. However, we first need to make the : placeholders explicit by spelling out each individual accumulation term. To do so, consider the shapes of the intermediate tensors: the first, mat_B[x,y,z], is shaped (H,), while the second, mat_A[z,x], is shaped (L, H).
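For instance, with the dimensions defined above (B=10, L=20, H=5), you can check these shapes directly:
>>> mat_B[0, 0, 0].shape
torch.Size([5])
>>> mat_A[0, 0].shape
torch.Size([20, 5])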
In pseudo-code, your initial operation is:
for x, y, z in LxBxB:
    tmp_B[:, y, z, :] = mat_B[x, y, z, :] * mat_A[z, x, :, :]
Knowing this, we can reformulate the loop in pseudo-code with every index written out explicitly:
for x, y, z, l, h in LxBxBxLxH:
    tmp_B[l, y, z, h] = mat_B[x, y, z, h] * mat_A[z, x, l, h]
Therefore, we can apply torch.einsum using the same index letters as subscripts:
>>> torch.einsum('xyzh,zxlh->lyzh', mat_B, mat_A)
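As a sanity check, here is a minimal sketch comparing the two. Note that einsum sums over any index that does not appear in the output subscripts (here x), so this matches your loop under the assumption that the assignment is meant to accumulate over x (i.e. += rather than =); with a plain =, only the last x iteration would survive.

import torch

B, L, H = 10, 20, 5
mat_A = torch.randn(B, L, L, H)
mat_B = torch.randn(L, B, B, H)

# Loop reference, accumulating over x.
ref = torch.zeros(L, B, B, H)
for x in range(L):
    for y in range(B):
        for z in range(B):
            ref[:, y, z, :] += mat_B[x, y, z, :] * mat_A[z, x, :, :]

# Vectorized equivalent: x is contracted, l/y/z/h are kept.
out = torch.einsum('xyzh,zxlh->lyzh', mat_B, mat_A)
print(torch.allclose(ref, out, atol=1e-5))  # True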