How can I remove a "for loop iteration" by using torch tensor operators?

Time:06-22

I want to get rid of the "for loop iteration" in my code by using PyTorch functions, but the formula is complicated and I can't find a clue. Can the for loop below be replaced with a Torch operation?

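import torch

B = 10
L = 20
H = 5

mat_A = torch.randn(B, L, L, H)
mat_B = torch.randn(L, B, B, H)
tmp_B = torch.zeros_like(mat_B)  # shape (L, B, B, H)

for x in range(L):
    for y in range(B):
        for z in range(B):
            # (H,) * (L, H) broadcasts to (L, H), the shape of tmp_B[:, y, z, :]
            # note: `=` overwrites on every x, so only x = L-1 is kept at the end
            tmp_B[:, y, z, :] = mat_B[x, y, z, :] * mat_A[z, x, :, :]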

Answer:

This looks like a good fit for torch.einsum. However, we first need to make the : placeholders explicit by writing out each individual accumulation term.

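To do so, consider the shapes of the intermediate results. The first, mat_B[x,y,z], has shape (H,), while the second, mat_A[z,x], has shape (L,H). Broadcasting stretches the first along the l axis, so their product has shape (L,H), which is the shape of the slice tmp_B[:,y,z,:].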
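The shapes described above can be checked directly. This is a minimal sketch that reuses the question's sizes with random inputs; the particular (x, y, z) triple is an arbitrary choice for inspection.

```python
import torch

# Same sizes as in the question
B, L, H = 10, 20, 5
mat_A = torch.randn(B, L, L, H)
mat_B = torch.randn(L, B, B, H)

# Arbitrary (x, y, z) triple used only to inspect one loop iteration
x, y, z = 3, 1, 2
term_B = mat_B[x, y, z]   # only h is left free -> shape (H,)
term_A = mat_A[z, x]      # l and h are left free -> shape (L, H)

# Broadcasting repeats term_B along the l axis, giving an (L, H) result
product = term_B * term_A

print(term_B.shape)   # torch.Size([5])
print(term_A.shape)   # torch.Size([20, 5])
print(product.shape)  # torch.Size([20, 5])
```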

In pseudo-code, accumulating the product over x, your initial operation is as follows:

for x, y, z in LxBxB:
    tmp_B[:,y,z,:] += mat_B[x,y,z,:]*mat_A[z,x,:,:]

Giving each : placeholder its own index, l (ranging over L) and h (ranging over H), we can reformulate the loop in pseudo-code as:

for x, y, z, l, h in LxBxBxLxH:
    tmp_B[l,y,z,h] += mat_B[x,y,z,h]*mat_A[z,x,l,h]
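As a sketch, the slice-based loop and the fully indexed loop can be made runnable with itertools.product and compared. This assumes the product is accumulated over x (+=), which is what the einsum summation below computes; the small sizes are an arbitrary choice so the element-wise loop stays fast.

```python
import itertools
import torch

# Small sizes (an assumption): the element-wise loop runs L*B*B*L*H times
B, L, H = 3, 4, 2
mat_A = torch.randn(B, L, L, H)
mat_B = torch.randn(L, B, B, H)

# Slice-based loop, accumulating over x
tmp_slices = torch.zeros(L, B, B, H)
for x, y, z in itertools.product(range(L), range(B), range(B)):
    tmp_slices[:, y, z, :] += mat_B[x, y, z, :] * mat_A[z, x, :, :]

# Fully explicit loop: every ':' replaced by its own index (l over L, h over H)
tmp_scalar = torch.zeros(L, B, B, H)
for x, y, z, l, h in itertools.product(range(L), range(B), range(B), range(L), range(H)):
    tmp_scalar[l, y, z, h] += mat_B[x, y, z, h] * mat_A[z, x, l, h]

print(torch.allclose(tmp_slices, tmp_scalar))  # True
```

Both loops add the x terms in the same order for every output element, so they agree exactly; the second form is the one whose indices map one-to-one onto einsum subscripts.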

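Therefore, we can apply torch.einsum using the same index notation as above. Because x does not appear in the output subscripts lyzh, einsum sums over it, which corresponds to accumulating (+=) over x rather than overwriting: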

>>> torch.einsum('xyzh,zxlh->lyzh', mat_B, mat_A)
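To check the einsum line against explicit loops, here is a sketch assuming the question's sizes and random inputs. It compares the einsum with the loop written with += (a sum over x). For the loop exactly as posted with =, where only the last x survives, it uses a second einsum on the x = L-1 slices; that second form is an addition here, not part of the original answer.

```python
import torch

B, L, H = 10, 20, 5
mat_A = torch.randn(B, L, L, H)
mat_B = torch.randn(L, B, B, H)

# Reference 1: loop with +=, i.e. a sum over x
loop_sum = torch.zeros_like(mat_B)
for x in range(L):
    for y in range(B):
        for z in range(B):
            loop_sum[:, y, z, :] += mat_B[x, y, z, :] * mat_A[z, x, :, :]

# Reference 2: loop exactly as posted (=), where only x = L-1 survives
loop_last = torch.zeros_like(mat_B)
for x in range(L):
    for y in range(B):
        for z in range(B):
            loop_last[:, y, z, :] = mat_B[x, y, z, :] * mat_A[z, x, :, :]

# x is missing from the output subscripts, so einsum sums over it
summed = torch.einsum('xyzh,zxlh->lyzh', mat_B, mat_A)

# Fixing x = L-1 drops it from both operands, so nothing is summed
last = torch.einsum('yzh,zlh->lyzh', mat_B[-1], mat_A[:, -1])

# atol absorbs float32 rounding from a possibly different summation order
print(torch.allclose(loop_sum, summed, atol=1e-4))  # True
print(torch.allclose(loop_last, last))              # True
```

If the overwrite in the posted loop was unintended, the first einsum is the replacement; the second is only relevant if keeping just the last x was the goal.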