I'm trying to re-implement a Transformer written in PyTorch in TensorFlow. The problem is that, even though I followed the documentation for both the PyTorch and the TensorFlow versions, the two still produce quite different results. I wrote a small code snippet to show the issue:
```python
import numpy as np
import tensorflow as tf
import torch


class TransformerLayer(tf.Module):
    def __init__(self, d_model, nhead, dropout=0):
        super(TransformerLayer, self).__init__()
        self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)


batch_size = 2
seq_length = 5
d_model = 10

src = np.random.uniform(size=(batch_size, seq_length, d_model)).astype(np.float32)
srcTF = tf.convert_to_tensor(src)  # (batch, seq, d_model)
# nn.MultiheadAttention expects (seq, batch, d_model) by default: swap the first two
# axes (a plain reshape would mix tokens from different sequences).
srcPT = torch.Tensor(np.ascontiguousarray(src.transpose(1, 0, 2)))

self_attnTF = tf.keras.layers.MultiHeadAttention(key_dim=10, num_heads=5, dropout=0)
transformer_encoder = TransformerLayer(d_model=10, nhead=5, dropout=0.0)

output, scores = self_attnTF(srcTF, srcTF, srcTF, return_attention_scores=True)
print("Tensorflow Attention outputs:", output)
# scores has shape (batch, heads, query, key); average over the heads axis
print("Tensorflow (averaged) weights:", tf.math.reduce_mean(scores, 1))
print("Torch Attention outputs:", transformer_encoder.self_attn(srcPT, srcPT, srcPT)[0])
print("Torch attention output weights:", transformer_encoder.self_attn(srcPT, srcPT, srcPT)[1])
```
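A note on the input layout, since it matters as soon as the weights match: `nn.MultiheadAttention` expects `(seq, batch, d_model)` by default (unless `batch_first=True`), and NumPy's `reshape` does not swap axes, it only reinterprets the data in memory order. This small check, using the same shapes as the question, shows that `reshape` puts the wrong token at a given position while `transpose(1, 0, 2)` does not:

```python
import numpy as np

batch_size, seq_length, d_model = 2, 5, 10
src = np.random.uniform(size=(batch_size, seq_length, d_model))

reshaped = src.reshape((seq_length, batch_size, d_model))  # reinterprets memory order
transposed = src.transpose(1, 0, 2)                         # swaps the batch and seq axes

# In (seq, batch, d_model) layout, position [3, 1] should hold token 3 of sequence 1.
print(np.array_equal(transposed[3, 1], src[1, 3]))  # True
print(np.array_equal(reshaped[3, 1], src[1, 2]))    # True: reshape picked token 2 instead
```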
and the result is:
Tensorflow Attention outputs: tf.Tensor(
[[[ 0.02602757 -0.14134401 0.00855263 0.4735083 -0.01851891
-0.20382246 -0.18152176 -0.21076852 0.08623976 -0.33548725]
[ 0.02607442 -0.1403394 0.00814065 0.47415024 -0.01882939
-0.20353754 -0.18291879 -0.21234266 0.08595885 -0.33613583]
[ 0.02524654 -0.14096384 0.00870436 0.47411725 -0.01800703
-0.20486829 -0.18163288 -0.21082559 0.08571021 -0.3362339 ]
[ 0.02518575 -0.14039244 0.0090138 0.47431853 -0.01775141
-0.20391947 -0.18138805 -0.2118245 0.08432849 -0.33521986]
[ 0.02556361 -0.14039293 0.00876258 0.4746476 -0.01891363
-0.20398234 -0.18229616 -0.21147579 0.08555281 -0.33639923]]
[[ 0.07844199 -0.1614371 0.01649148 0.5287745 0.05126739
-0.13851154 -0.09829871 -0.1621251 0.01922669 -0.2428589 ]
[ 0.07844222 -0.16024739 0.01805423 0.52941847 0.04975721
-0.13537636 -0.09829231 -0.16129729 0.01979005 -0.24491176]
[ 0.07800542 -0.160701 0.01677295 0.52902794 0.05082911
-0.13843337 -0.09805533 -0.16165744 0.01928401 -0.24327613]
[ 0.07815789 -0.1600025 0.01757433 0.5291927 0.05032986
-0.1368022 -0.09849522 -0.16172451 0.01929555 -0.24438493]
[ 0.0781548 -0.16028519 0.01764914 0.52846324 0.04941286
-0.13746066 -0.09787872 -0.16141161 0.01994199 -0.2440269 ]]], shape=(2, 5, 10), dtype=float32)
Tensorflow (averaged) weights: tf.Tensor(
[[[0.199085 0.20275716 0.20086522 0.19873264 0.19856 ]
[0.2015336 0.19960018 0.20218948 0.19891861 0.19775811]
[0.19906266 0.20318432 0.20190334 0.19812575 0.19772394]
[0.20074987 0.20104568 0.20269363 0.19744729 0.19806348]
[0.19953248 0.20176074 0.20314851 0.19782843 0.19772986]]
[[0.2010009 0.20053487 0.20004745 0.20092985 0.19748697]
[0.20034568 0.20035927 0.19955876 0.20062163 0.19911464]
[0.19967113 0.2006859 0.20012529 0.20047483 0.19904283]
[0.20132652 0.19996871 0.20019794 0.20008174 0.19842513]
[0.2006393 0.20000939 0.19938737 0.20054278 0.19942114]]], shape=(2, 5, 5), dtype=float32)
Torch Attention outputs: tensor([[[ 0.1097, -0.4467, -0.0719, -0.1779, -0.0766, -0.1247, 0.1557,
0.0051, -0.3932, -0.1323],
[ 0.1264, -0.3822, 0.0759, -0.0335, -0.1084, -0.1539, 0.1475,
-0.0272, -0.4235, -0.1744]],
[[ 0.1122, -0.4502, -0.0747, -0.1796, -0.0756, -0.1271, 0.1581,
0.0049, -0.3964, -0.1340],
[ 0.1274, -0.3823, 0.0754, -0.0356, -0.1091, -0.1547, 0.1477,
-0.0272, -0.4252, -0.1752]],
[[ 0.1089, -0.4427, -0.0728, -0.1746, -0.0756, -0.1202, 0.1501,
0.0031, -0.3894, -0.1242],
[ 0.1263, -0.3820, 0.0718, -0.0374, -0.1063, -0.1562, 0.1485,
-0.0271, -0.4233, -0.1761]],
[[ 0.1061, -0.4369, -0.0685, -0.1696, -0.0772, -0.1173, 0.1454,
0.0012, -0.3860, -0.1201],
[ 0.1265, -0.3820, 0.0762, -0.0325, -0.1082, -0.1560, 0.1501,
-0.0271, -0.4249, -0.1779]],
[[ 0.1043, -0.4402, -0.0705, -0.1719, -0.0791, -0.1205, 0.1508,
0.0018, -0.3895, -0.1262],
[ 0.1260, -0.3805, 0.0775, -0.0298, -0.1083, -0.1547, 0.1494,
-0.0276, -0.4242, -0.1768]]], grad_fn=<AddBackward0>)
Torch attention output weights: tensor([[[0.2082, 0.2054, 0.1877, 0.1956, 0.2031],
[0.2100, 0.2079, 0.1841, 0.1943, 0.2037],
[0.2007, 0.1995, 0.1929, 0.1999, 0.2070],
[0.1995, 0.1950, 0.1976, 0.2002, 0.2077],
[0.1989, 0.1969, 0.1970, 0.2024, 0.2048]],
[[0.2095, 0.1902, 0.1987, 0.2027, 0.1989],
[0.2090, 0.1956, 0.1997, 0.2004, 0.1952],
[0.2047, 0.1869, 0.2006, 0.2121, 0.1957],
[0.2073, 0.1953, 0.1982, 0.2014, 0.1978],
[0.2089, 0.2003, 0.1953, 0.1957, 0.1998]]], grad_fn=<DivBackward0>)
The attention weights look similar, but the attention outputs themselves are way off. Is there any way to make the TensorFlow model produce the same results as the PyTorch one? Any help would be greatly appreciated!
Answer:
MultiHeadAttention also contains learned projection layers that are applied to the inputs before the attention itself:
Q = W_q @ input_query + b_q
K = W_k @ input_keys + b_k
V = W_v @ input_values + b_v
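To make the role of these projections concrete, below is a plain-NumPy sketch of multi-head self-attention (no masks, no dropout). The packed `in_proj_weight` / `out_proj_weight` argument names mirror PyTorch's parameters, but the function itself is only an illustration, not either library's code. It uses the row-vector convention `x @ W.T + b`, which is the same projection as `W @ x + b` applied per token. Feeding one input through two random sets of weights gives two different outputs, which is exactly what happens with two freshly created layers.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multihead_self_attention(x, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, nhead):
    """Self-attention for x of shape (seq, d_model); no masks, no dropout."""
    seq, d_model = x.shape
    head_dim = d_model // nhead
    # PyTorch-style packing: the rows of in_proj_weight are [W_q; W_k; W_v].
    W_q, W_k, W_v = np.split(in_proj_weight, 3, axis=0)
    b_q, b_k, b_v = np.split(in_proj_bias, 3)
    # Input projections (row-vector convention, i.e. x @ W.T + b).
    Q, K, V = x @ W_q.T + b_q, x @ W_k.T + b_k, x @ W_v.T + b_v

    def split_heads(t):  # (seq, d_model) -> (nhead, seq, head_dim)
        return t.reshape(seq, nhead, head_dim).transpose(1, 0, 2)

    Q, K, V = split_heads(Q), split_heads(K), split_heads(V)
    weights = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim))  # (nhead, seq, seq)
    heads = weights @ V                                               # (nhead, seq, head_dim)
    concat = heads.transpose(1, 0, 2).reshape(seq, d_model)
    # Final output projection, also a learned (randomly initialized) layer.
    output = concat @ out_proj_weight.T + out_proj_bias
    return output, weights.mean(axis=0)  # head-averaged weights, like PyTorch's default


rng = np.random.default_rng(0)
d_model, nhead, seq = 10, 5, 5
x = rng.uniform(size=(seq, d_model))


def random_params():
    # Random weight matrices and zero biases (illustrative init only).
    return (rng.normal(size=(3 * d_model, d_model)), np.zeros(3 * d_model),
            rng.normal(size=(d_model, d_model)), np.zeros(d_model))


out_a, attn_a = multihead_self_attention(x, *random_params(), nhead)
out_b, attn_b = multihead_self_attention(x, *random_params(), nhead)
print(np.allclose(out_a, out_b))  # False: same input, different random projections
```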
The matrices `W_q`, `W_k` and `W_v` are initialized randomly (the biases `b_q`, `b_k` and `b_v` start at zero by default in both libraries), so the outputs are expected to differ, even between two separate PyTorch layers given the same input. After the attention step there is one more projection (`out_proj` in PyTorch), whose weight matrix is also initialized randomly. In TensorFlow you can set the weights manually by calling the `set_weights` method of `self_attnTF`.
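A quick PyTorch-only check of the initialization point: two `nn.MultiheadAttention` layers with the same configuration give different outputs for the same input, and agree once one layer's parameters are copied into the other with `load_state_dict`. On the Keras side, `set_weights` is the corresponding way to overwrite a layer's parameters.

```python
import torch

d_model, nhead = 10, 5
x = torch.rand(5, 2, d_model)  # (seq, batch, d_model), PyTorch's default layout

layer_a = torch.nn.MultiheadAttention(d_model, nhead)
layer_b = torch.nn.MultiheadAttention(d_model, nhead)  # same config, new random weights

with torch.no_grad():
    out_a, _ = layer_a(x, x, x)
    out_b, _ = layer_b(x, x, x)
    same_before = torch.allclose(out_a, out_b)

    layer_b.load_state_dict(layer_a.state_dict())  # copy every parameter from layer_a
    out_b, _ = layer_b(x, x, x)
    same_after = torch.allclose(out_a, out_b)

print(same_before, same_after)  # False True
```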
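The correspondence between the weights of `tf.keras.layers.MultiHeadAttention` and `nn.MultiheadAttention` is not obvious. For example, PyTorch packs the query, key and value projections of all heads into a single `in_proj_weight` matrix of shape `(3 * embed_dim, embed_dim)`, while Keras keeps a separate kernel for each projection with an explicit heads axis, of shape `(d_model, num_heads, key_dim)`. So if you want to load the weights of a pretrained PyTorch model into a TensorFlow model (for whatever reason), expect it to take more than five minutes.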
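Most of the mapping is reshaping. Below is a NumPy sketch with a hypothetical helper, `torch_mha_to_keras`, that assumes Keras's `key_dim` equals `d_model // nhead`. (In Keras `key_dim` is the size of each head; with `key_dim=10, num_heads=5` as in the question, the Keras layer has 2160 parameters versus 440 in the PyTorch one, so the two are not the same architecture to begin with.) The helper splits the packed matrix into Q, K and V, transposes each to match the `x @ W.T` convention, and splits the feature axis into `(nhead, head_dim)`; random arrays then confirm that both layouts give the same projections.

```python
import numpy as np


def torch_mha_to_keras(in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, nhead):
    """Rearrange PyTorch-style packed MHA parameters into Keras-style per-head arrays.

    Returns [query kernel, query bias, key kernel, key bias,
             value kernel, value bias, output kernel, output bias].
    """
    d_model = out_proj_weight.shape[0]
    head_dim = d_model // nhead
    params = []
    for w, b in zip(np.split(in_proj_weight, 3, axis=0), np.split(in_proj_bias, 3)):
        # PyTorch computes x @ w.T; split its output features into (nhead, head_dim).
        params.append(w.T.reshape(d_model, nhead, head_dim))
        params.append(b.reshape(nhead, head_dim))
    # The concatenated heads are ordered (head, head_dim), so W_o.T splits the same way.
    params.append(out_proj_weight.T.reshape(nhead, head_dim, d_model))
    params.append(out_proj_bias)
    return params


rng = np.random.default_rng(0)
d_model, nhead, seq = 10, 5, 4
W_in, b_in = rng.normal(size=(3 * d_model, d_model)), rng.normal(size=3 * d_model)
W_out, b_out = rng.normal(size=(d_model, d_model)), rng.normal(size=d_model)
kq, bq, kk, bk, kv, bv, ko, bo = torch_mha_to_keras(W_in, b_in, W_out, b_out, nhead)

x = rng.normal(size=(seq, d_model))
# Query projection: Keras-style einsum vs. PyTorch-style packed matmul
q_keras = np.einsum("sc,chd->shd", x, kq) + bq
q_torch = (x @ W_in[:d_model].T + b_in[:d_model]).reshape(seq, nhead, -1)
print(np.allclose(q_keras, q_torch))  # True

# Output projection on stand-in per-head attention outputs
heads = rng.normal(size=(seq, nhead, d_model // nhead))
o_keras = np.einsum("shd,hde->se", heads, ko) + bo
o_torch = heads.reshape(seq, d_model) @ W_out.T + b_out
print(np.allclose(o_keras, o_torch))  # True
```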
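The results should be the same if, after initializing the PyTorch and TensorFlow models, you step through their parameters and assign them identical values, provided the two layers have the same per-head size (Keras's `key_dim` must equal `d_model // nhead`) and each library receives its input in the layout it expects (PyTorch uses `(seq, batch, d_model)` by default).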