Trying to achieve same result with Pytorch and Tensorflow MultiheadAttention

Time:04-21

I'm trying to recreate a transformer written in Pytorch and implement it in Tensorflow. The problem is that, despite following the documentation for both the Pytorch and the Tensorflow versions, the outputs still come out pretty different. I wrote a little code snippet to show the issue:

import torch
import tensorflow as tf
import numpy as np

class TransformerLayer(tf.Module):
    def __init__(self, d_model, nhead, dropout=0):
        super(TransformerLayer, self).__init__()
        self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)

batch_size = 2
seq_length = 5
d_model = 10

src = np.random.uniform(size=(batch_size, seq_length, d_model))
srcTF = tf.convert_to_tensor(src)
srcPT = torch.Tensor(src.reshape((seq_length, batch_size, d_model)))

self_attnTF = tf.keras.layers.MultiHeadAttention(key_dim=10, num_heads=5, dropout=0)
transformer_encoder = TransformerLayer(d_model=10, nhead=5, dropout=0.0)

output, scores = self_attnTF(srcTF, srcTF, srcTF, return_attention_scores=True)
print("Tensorflow Attendtion outputs:", output)
print("Tensorflow (averaged) weights:", tf.math.reduce_mean(scores, 1))
print("Torch Attendtion outputs:", transformer_encoder.self_attn(srcPT,srcPT,srcPT)[0])
print("Torch attention output weights:", transformer_encoder.self_attn(srcPT,srcPT,srcPT)[1])

and the result is:

Tensorflow Attendtion outputs: tf.Tensor(
[[[ 0.02602757 -0.14134401  0.00855263  0.4735083  -0.01851891
   -0.20382246 -0.18152176 -0.21076852  0.08623976 -0.33548725]
  [ 0.02607442 -0.1403394   0.00814065  0.47415024 -0.01882939
   -0.20353754 -0.18291879 -0.21234266  0.08595885 -0.33613583]
  [ 0.02524654 -0.14096384  0.00870436  0.47411725 -0.01800703
   -0.20486829 -0.18163288 -0.21082559  0.08571021 -0.3362339 ]
  [ 0.02518575 -0.14039244  0.0090138   0.47431853 -0.01775141
   -0.20391947 -0.18138805 -0.2118245   0.08432849 -0.33521986]
  [ 0.02556361 -0.14039293  0.00876258  0.4746476  -0.01891363
   -0.20398234 -0.18229616 -0.21147579  0.08555281 -0.33639923]]

 [[ 0.07844199 -0.1614371   0.01649148  0.5287745   0.05126739
   -0.13851154 -0.09829871 -0.1621251   0.01922669 -0.2428589 ]
  [ 0.07844222 -0.16024739  0.01805423  0.52941847  0.04975721
   -0.13537636 -0.09829231 -0.16129729  0.01979005 -0.24491176]
  [ 0.07800542 -0.160701    0.01677295  0.52902794  0.05082911
   -0.13843337 -0.09805533 -0.16165744  0.01928401 -0.24327613]
  [ 0.07815789 -0.1600025   0.01757433  0.5291927   0.05032986
   -0.1368022  -0.09849522 -0.16172451  0.01929555 -0.24438493]
  [ 0.0781548  -0.16028519  0.01764914  0.52846324  0.04941286
   -0.13746066 -0.09787872 -0.16141161  0.01994199 -0.2440269 ]]], shape=(2, 5, 10), dtype=float32)
Tensorflow (averaged) weights: tf.Tensor(
[[[0.199085   0.20275716 0.20086522 0.19873264 0.19856   ]
  [0.2015336  0.19960018 0.20218948 0.19891861 0.19775811]
  [0.19906266 0.20318432 0.20190334 0.19812575 0.19772394]
  [0.20074987 0.20104568 0.20269363 0.19744729 0.19806348]
  [0.19953248 0.20176074 0.20314851 0.19782843 0.19772986]]

 [[0.2010009  0.20053487 0.20004745 0.20092985 0.19748697]
  [0.20034568 0.20035927 0.19955876 0.20062163 0.19911464]
  [0.19967113 0.2006859  0.20012529 0.20047483 0.19904283]
  [0.20132652 0.19996871 0.20019794 0.20008174 0.19842513]
  [0.2006393  0.20000939 0.19938737 0.20054278 0.19942114]]], shape=(2, 5, 5), dtype=float32)
Torch Attendtion outputs: tensor([[[ 0.1097, -0.4467, -0.0719, -0.1779, -0.0766, -0.1247,  0.1557,
           0.0051, -0.3932, -0.1323],
         [ 0.1264, -0.3822,  0.0759, -0.0335, -0.1084, -0.1539,  0.1475,
          -0.0272, -0.4235, -0.1744]],

        [[ 0.1122, -0.4502, -0.0747, -0.1796, -0.0756, -0.1271,  0.1581,
           0.0049, -0.3964, -0.1340],
         [ 0.1274, -0.3823,  0.0754, -0.0356, -0.1091, -0.1547,  0.1477,
          -0.0272, -0.4252, -0.1752]],

        [[ 0.1089, -0.4427, -0.0728, -0.1746, -0.0756, -0.1202,  0.1501,
           0.0031, -0.3894, -0.1242],
         [ 0.1263, -0.3820,  0.0718, -0.0374, -0.1063, -0.1562,  0.1485,
          -0.0271, -0.4233, -0.1761]],

        [[ 0.1061, -0.4369, -0.0685, -0.1696, -0.0772, -0.1173,  0.1454,
           0.0012, -0.3860, -0.1201],
         [ 0.1265, -0.3820,  0.0762, -0.0325, -0.1082, -0.1560,  0.1501,
          -0.0271, -0.4249, -0.1779]],

        [[ 0.1043, -0.4402, -0.0705, -0.1719, -0.0791, -0.1205,  0.1508,
           0.0018, -0.3895, -0.1262],
         [ 0.1260, -0.3805,  0.0775, -0.0298, -0.1083, -0.1547,  0.1494,
          -0.0276, -0.4242, -0.1768]]], grad_fn=<AddBackward0>)
Torch attention output weights: tensor([[[0.2082, 0.2054, 0.1877, 0.1956, 0.2031],
         [0.2100, 0.2079, 0.1841, 0.1943, 0.2037],
         [0.2007, 0.1995, 0.1929, 0.1999, 0.2070],
         [0.1995, 0.1950, 0.1976, 0.2002, 0.2077],
         [0.1989, 0.1969, 0.1970, 0.2024, 0.2048]],

        [[0.2095, 0.1902, 0.1987, 0.2027, 0.1989],
         [0.2090, 0.1956, 0.1997, 0.2004, 0.1952],
         [0.2047, 0.1869, 0.2006, 0.2121, 0.1957],
         [0.2073, 0.1953, 0.1982, 0.2014, 0.1978],
         [0.2089, 0.2003, 0.1953, 0.1957, 0.1998]]], grad_fn=<DivBackward0>)

The output weights look similar but the base attention outputs are way off. Is there any way to make the Tensorflow model come out more like the Pytorch one? Any help would be greatly appreciated!

CodePudding user response:

MultiHeadAttention also contains projection layers, like

Q = W_q @ input_query + b_q
K = W_k @ input_keys + b_k
V = W_v @ input_values + b_v

The matrices W_q, W_k and W_v and the biases b_q, b_k, b_v are initialized randomly, so a difference in outputs should be expected (even between two distinct pytorch layers given the same input). After the self-attention operation there is one more projection, and it is also initialized randomly. In tensorflow the weights can be set manually by calling the set_weights method of self_attnTF.
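
To make the point about random initialization concrete, here is a minimal NumPy sketch of multi-head self-attention. It is not the code of either library; init_params and self_attention are hypothetical helpers written only for this illustration. Two differently seeded parameter sets give different outputs on the same input, while reusing one parameter set gives identical outputs.

```python
import numpy as np

def init_params(rng, d_model):
    """Random weights and biases for the q, k, v and output projections.
    Hypothetical helper for this sketch, not a library API."""
    return {
        name: (rng.normal(scale=0.3, size=(d_model, d_model)),
               rng.normal(scale=0.3, size=d_model))
        for name in ("q", "k", "v", "o")
    }

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(x, params, nhead):
    """x has shape (batch, seq, d_model). Returns (output, head-averaged weights)."""
    b, s, d = x.shape
    head_dim = d // nhead

    def project(name):
        w, bias = params[name]
        y = x @ w.T + bias  # e.g. Q = W_q @ input_query + b_q, applied to every row
        return y.reshape(b, s, nhead, head_dim).transpose(0, 2, 1, 3)  # (b, h, s, hd)

    q, k, v = project("q"), project("k"), project("v")
    weights = softmax(q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim))  # (b, h, s, s)
    context = (weights @ v).transpose(0, 2, 1, 3).reshape(b, s, d)      # merge heads
    w_o, b_o = params["o"]
    return context @ w_o.T + b_o, weights.mean(axis=1)

x = np.random.default_rng(0).uniform(size=(2, 5, 10))
params_a = init_params(np.random.default_rng(1), d_model=10)
params_b = init_params(np.random.default_rng(2), d_model=10)

out_a, w_a = self_attention(x, params_a, nhead=5)
out_b, w_b = self_attention(x, params_b, nhead=5)
out_a_again, _ = self_attention(x, params_a, nhead=5)

print("different init, max |diff|:", np.abs(out_a - out_b).max())
print("same init, identical output:", np.array_equal(out_a, out_a_again))
```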

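The correspondence between the weights in tf.keras.layers.MultiHeadAttention and nn.MultiheadAttention is not obvious. For example, torch packs the query, key and value projections for all heads into a single in_proj_weight matrix, while tf keeps a separate kernel for each projection with an explicit per-head axis. The head size is also specified differently: torch derives it as d_model / nhead, whereas key_dim in tf is the size of each head, so key_dim=10 with num_heads=5 in the snippet above does not match d_model=10 with nhead=5 in torch (that would be key_dim=2). So if you take the weights of a pretrained pytorch model and try to load them into a tensorflow model (for whatever reason), it'll certainly take more than five minutes.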
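
As a rough illustration of how the two storage layouts relate, the NumPy sketch below builds a torch-style stacked input projection of shape (3 * d_model, d_model) and reshapes its query part into a Keras-style kernel of shape (d_model, num_heads, key_dim). The einsum equations are the ones I believe Keras uses internally for its projection layers; treat that as an assumption and check the kernel shapes of your own layer. All array names here are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq, d_model, nhead = 2, 5, 10, 5
head_dim = d_model // nhead

x = rng.uniform(size=(batch, seq, d_model))

# torch-style storage: Q, K and V projections stacked along the first axis.
in_proj_weight = rng.normal(size=(3 * d_model, d_model))
in_proj_bias = rng.normal(size=3 * d_model)
w_q, w_k, w_v = np.split(in_proj_weight, 3, axis=0)
b_q, b_k, b_v = np.split(in_proj_bias, 3)

# torch-style query projection, then split the last axis into heads.
q_torch = (x @ w_q.T + b_q).reshape(batch, seq, nhead, head_dim)

# Keras-style query projection: kernel (d_model, heads, head_dim), bias (heads, head_dim).
kernel_q = w_q.T.reshape(d_model, nhead, head_dim)
bias_q = b_q.reshape(nhead, head_dim)
q_keras = np.einsum("abc,cde->abde", x, kernel_q) + bias_q

# Output projection: torch uses a (d_model, d_model) matrix on the concatenated heads,
# the Keras-style kernel has shape (heads, head_dim, d_model).
out_weight = rng.normal(size=(d_model, d_model))
out_bias = rng.normal(size=d_model)
heads = rng.normal(size=(batch, seq, nhead, head_dim))  # stand-in for per-head outputs

o_torch = heads.reshape(batch, seq, d_model) @ out_weight.T + out_bias
kernel_o = out_weight.T.reshape(nhead, head_dim, d_model)
o_keras = np.einsum("abcd,cde->abe", heads, kernel_o) + out_bias

print("query projections match:", np.allclose(q_torch, q_keras))
print("output projections match:", np.allclose(o_torch, o_keras))
```

The same reshape applies to the key and value parts (w_k, b_k and w_v, b_v).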

Results should be the same if, after initializing the pytorch model and the tensorflow model, you step through their parameters and assign them identical values.
