Transformer Positional Encoding -- What is maxlen used for


import math

import torch
from torch import nn, Tensor


class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        # Precompute the sinusoidal table for maxlen positions
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)  # even dimensions
        pos_embedding[:, 1::2] = torch.cos(pos * den)  # odd dimensions
        pos_embedding = pos_embedding.unsqueeze(-2)    # shape: (maxlen, 1, emb_size)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        # token_embedding: (seq_len, batch_size, emb_size);
        # only the first seq_len rows of the table are added
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

This code is from here.

I know what positional encoding is used for, but is maxlen a constant value? Or does it vary depending on the batch size or the length of the data?

An example from NLP:

  [data_length, batch_size]

  [256, 64]

  64 * 256 = 16,384 tokens are obtained.

What I don't understand is whether the maxlen value of 5000 used in the positional encoding has anything to do with these shapes.

Am I using it wrong?

Should maxlen be changed according to the example I gave?
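
To make the shapes concrete, here is a minimal check I would run with the class above (emb_size = 512 and dropout = 0.1 are just placeholder values, not something from my real model):

  import torch

  # assumes the PositionalEncoding class defined above is in scope
  pos_enc = PositionalEncoding(emb_size=512, dropout=0.1)

  # token embeddings shaped [data_length, batch_size, emb_size] = [256, 64, 512]
  token_embedding = torch.zeros(256, 64, 512)

  out = pos_enc(token_embedding)
  print(out.shape)                    # torch.Size([256, 64, 512])
  print(pos_enc.pos_embedding.shape)  # torch.Size([5000, 1, 512])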

CodePudding user response:

The purpose of maxlen is to put an upper bound on the length of the input texts the module can handle: the positional encoding table is precomputed for maxlen positions, so every sequence you feed in must be at most that long. The number is generally chosen to comfortably cover the longest texts you expect; picking something much larger is slightly wasteful in memory but otherwise harmless, so any reasonable number works. Texts longer than maxlen would have to be truncated to maxlen, which in the example is 5000 positions. You can read this explanation for more info.
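
A small sketch to illustrate this (it reuses the PositionalEncoding class from the question; emb_size = 512 and dropout = 0.0 are assumed values): forward only slices the first seq_len rows of the table, so any maxlen at least as large as the longest sequence produces identical encodings.

  import torch

  small = PositionalEncoding(emb_size=512, dropout=0.0, maxlen=256)
  large = PositionalEncoding(emb_size=512, dropout=0.0, maxlen=5000)

  tokens = torch.zeros(256, 64, 512)  # [data_length, batch_size, emb_size]

  # dropout is 0.0 here, so both calls are deterministic and give the same result
  assert torch.equal(small(tokens), large(tokens))

So for the [256, 64] example above, any maxlen >= 256 works; the batch size of 64 plays no role in it.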
