Reshaping tensors for multi-head attention in PyTorch - view vs transpose

Time:11-25

I'm learning about the attention operator in deep learning. I understand that to compute multi-head attention efficiently in parallel, the input tensors (query, key, value) have to be reshaped properly. Assuming query, key and value are three tensors of identical shape [N, L, D], in which

  • N is the batch size
  • L is the sequence length
  • D is the hidden/embedding size,

they should be turned into [N*N_H, L, D_H] tensors, where N_H is the number of heads for the attention layer and D_H is the embedding size of each head (so D = N_H * D_H).
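For the batch-first [N, L, D] layout described above, a minimal sketch of that reshape (sizes are made up for illustration) splits D into (N_H, D_H), moves the head axis next to the batch axis, and then merges the two. It checks that each leading slice really is one head of one sequence:

```python
import torch

# Hypothetical sizes chosen for illustration only.
N, L, N_H, D_H = 2, 5, 4, 3
D = N_H * D_H  # the embedding size must split evenly across heads

q = torch.randn(N, L, D)  # batch-first layout [N, L, D]

# [N, L, D] -> [N, L, N_H, D_H] -> [N, N_H, L, D_H] -> [N*N_H, L, D_H]
q_heads = q.view(N, L, N_H, D_H).transpose(1, 2).reshape(N * N_H, L, D_H)

# Slice n * N_H + h should hold head h of batch element n:
# columns h*D_H:(h+1)*D_H of every token in that sequence.
n, h = 1, 2
assert torch.equal(q_heads[n * N_H + h], q[n, :, h * D_H:(h + 1) * D_H])
print(q_heads.shape)  # torch.Size([8, 5, 3])
```

The last step uses reshape rather than view because the transposed tensor is no longer contiguous, and view cannot merge those two dimensions without a copy (it raises an error here).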

The PyTorch code seems to do exactly that. Below is the code for reshaping the query tensor (key and value are handled the same way):

q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
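The variable names in that line (tgt_len, bsz) match torch.nn.functional.multi_head_attention_forward, which works on a sequence-first query of shape [tgt_len, bsz, embed_dim], i.e. [L, N, D]; this is also what nn.MultiheadAttention uses by default (batch_first=False). Assuming that layout, a small sketch with made-up sizes checks what the line produces:

```python
import torch

# Hypothetical sizes for illustration.
tgt_len, bsz, num_heads, head_dim = 5, 2, 4, 3
embed_dim = num_heads * head_dim

# Sequence-first layout [L, N, D], assumed to match PyTorch's internal code.
q = torch.randn(tgt_len, bsz, embed_dim)

# Split embed_dim into heads, fold heads into the batch axis,
# then move the (batch * heads) axis to the front.
q_heads = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
print(q_heads.shape)  # torch.Size([8, 5, 3])

# Slice b * num_heads + h holds head h of batch element b for all tokens.
b, h = 1, 2
assert torch.equal(q_heads[b * num_heads + h], q[:, b, h * head_dim:(h + 1) * head_dim])
```

Because the sequence axis comes first in memory, the view can only split the last dimension into heads; the transpose is what brings the (batch * heads) axis to the front.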

I don't get why they perform both a view and a transpose call, when the result would be the same by just doing

q = q.contiguous().view(bsz * num_heads, tgt_len, head_dim)

Other than avoiding an additional function call, using view alone also guarantees that the resulting tensor is still contiguous in memory, whereas this doesn't hold (to the best of my knowledge) for transpose. I suppose working with contiguous data is beneficial whenever possible to make computations potentially faster (it may lead to fewer memory accesses, better exploitation of spatial locality, etc.).
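The contiguity part of this is easy to check. A short sketch (sizes are hypothetical) shows that view on a contiguous tensor stays contiguous, while transpose only swaps strides and returns a non-contiguous view of the same memory:

```python
import torch

# Hypothetical sizes for illustration.
tgt_len, bsz, num_heads, head_dim = 5, 2, 4, 3
q = torch.randn(tgt_len, bsz, num_heads * head_dim)

viewed = q.view(tgt_len, bsz * num_heads, head_dim)
transposed = viewed.transpose(0, 1)

print(viewed.is_contiguous())      # True: a view of a contiguous tensor keeps its layout
print(transposed.is_contiguous())  # False: transpose only swaps strides
print(transposed.stride())         # (3, 24, 1): no data was moved

# .contiguous() makes a copy laid out in the new dimension order.
packed = transposed.contiguous()
print(packed.is_contiguous())      # True
assert torch.equal(packed, transposed)  # same values, different memory layout
```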

What's the use case for having a transpose call after a view?

Answer:

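The results AREN'T necessarily the same. view only reinterprets the existing memory order under a new shape, whereas transpose swaps two dimensions, so the two approaches can put different elements at the same index: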

import torch

a = torch.arange(0, 2 * 3 * 4)

# view + transpose -> shape (3, 2, 4)
b = a.view(2, 3, 4).transpose(1, 0)
# tensor([[[ 0,  1,  2,  3],
#          [12, 13, 14, 15]],
#
#         [[ 4,  5,  6,  7],
#          [16, 17, 18, 19]],
#
#         [[ 8,  9, 10, 11],
#          [20, 21, 22, 23]]])

# view only -> same shape (3, 2, 4), different contents
c = a.view(3, 2, 4)
# tensor([[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7]],
#
#         [[ 8,  9, 10, 11],
#          [12, 13, 14, 15]],
#
#         [[16, 17, 18, 19],
#          [20, 21, 22, 23]]])
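The same effect shows up with attention-sized tensors. Assuming the sequence-first [tgt_len, bsz, embed_dim] layout and made-up sizes, this sketch compares the two reshapes from the question: both give the same shape, but only the view + transpose version makes slice 0 equal to head 0 of the first sequence. The view-only slice instead mixes several heads and batch elements of the first token:

```python
import torch

# Hypothetical sizes; q is sequence-first [tgt_len, bsz, embed_dim].
tgt_len, bsz, num_heads, head_dim = 5, 2, 4, 3
q = torch.arange(tgt_len * bsz * num_heads * head_dim, dtype=torch.float32)
q = q.view(tgt_len, bsz, num_heads * head_dim)

with_transpose = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
view_only = q.view(bsz * num_heads, tgt_len, head_dim)

print(with_transpose.shape == view_only.shape)  # True: same shape
print(torch.equal(with_transpose, view_only))   # False: different contents

# Slice 0 should be head 0 of batch element 0 across all 5 tokens.
expected = q[:, 0, :head_dim]
print(torch.equal(with_transpose[0], expected))  # True
print(torch.equal(view_only[0], expected))       # False
```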