I was advised here to use tf.TensorArray instead of tf.Variable or tf.queue.FIFOQueue to implement a FIFO buffer held inside a custom layer. Is that an efficient approach? Are there any alternatives?
If it is the most efficient method, how can I replace self.queue.assign(tf.concat([self.queue[timesteps:, :], inputs], axis=0)) with methods of tf.TensorArray?
Code
class FIFOLayer(Layer):
    """Sliding-window FIFO buffer stored as a non-trainable layer weight.

    The buffer holds ``window_size`` rows and starts out filled with NaN.
    Every call drops the oldest rows, appends the new ones at the end and
    returns the whole buffer with a leading axis of size 1, together with
    an attention mask while the buffer is still only partially filled.
    """

    def __init__(self, window_size, **kwargs):
        """Store the buffer capacity and reset the received-row counter.

        Args:
            window_size: Number of rows the buffer holds.
            **kwargs: Forwarded to the base layer constructor.
        """
        super(FIFOLayer, self).__init__(**kwargs)
        self.window_size = window_size
        # Rows received since construction or the last clear(); call()
        # caps it at window_size.
        self.count = 0

    def build(self, input_shape):
        """Create the ``(window_size, input_shape[-1])`` buffer, filled with NaN."""
        super(FIFOLayer, self).build(input_shape)
        self.queue = self.add_weight(
            name="queue",
            shape=(self.window_size, input_shape[-1]),
            # NaN marks rows that have not been written yet.
            initializer=tf.initializers.Constant(value=np.nan),
            trainable=False,
        )

    def call(self, inputs, training=None):
        """Push ``inputs`` into the buffer and return the updated window.

        Args:
            inputs: Tensor whose first axis indexes the new rows and whose
                last axis matches the buffer's feature size.
            training: Unused.

        Returns:
            A tuple ``(window, attention_mask)``. ``window`` is the buffer
            with a new leading axis. ``attention_mask`` is a
            ``(window_size, window_size)`` float32 tensor that is 1 where
            both rows are free of NaN, or ``None`` once at least
            ``window_size`` rows have been received.

        Raises:
            ValueError: If ``inputs`` has more rows than ``window_size``.
        """
        # NOTE(review): the Python-level comparisons on tensors below
        # presumably rely on eager execution - confirm before wrapping
        # this layer in a tf.function.
        timesteps = tf.shape(inputs)[0]
        # The new rows must fit into the buffer in one go.
        if timesteps > self.window_size:
            raise ValueError(
                f"inputs has more rows than window_size={self.window_size}"
            )
        # 1. Drop the oldest `timesteps` rows and append the new ones.
        self.queue.assign(tf.concat([self.queue[timesteps:, :], inputs], axis=0))
        # Accumulate (not overwrite) so the counter tracks the fill level.
        self.count += timesteps
        # 2. Feed-forward.
        if self.count < self.window_size:
            # A row is valid when none of its features is NaN.
            attention_mask = tf.cast(
                tf.math.reduce_all(
                    tf.math.logical_not(tf.math.is_nan(self.queue)), axis=-1
                ),
                dtype=tf.float32,
            )
            # Outer product of the validity vector with itself.
            attention_mask = tf.matmul(
                attention_mask[..., tf.newaxis],
                attention_mask[..., tf.newaxis],
                transpose_b=True,
            )
            return self.queue[tf.newaxis, ...], attention_mask
        # The counter may overshoot the capacity; clamp it so is_full
        # keeps reporting a full buffer.
        elif self.count > self.window_size:
            self.count = self.window_size
        return self.queue[tf.newaxis, ...], None

    @property
    def is_full(self):
        """Whether ``window_size`` rows have been received.

        The result is a scalar boolean tensor rather than a Python bool
        whenever ``count`` currently holds a tensor.
        """
        return self.count == self.window_size

    def clear(self):
        """Reset the counter and refill the buffer with NaN."""
        self.count = 0
        self.queue.assign(tf.fill(self.queue.shape, np.nan))
# Feed six random 2x12 batches through a 10-row window, printing each result.
fifo = FIFOLayer(window_size=10)
for _ in range(6):
    x = tf.random.normal((2, 12))
    y = fifo(x)
    print(y)
print(fifo.is_full, "\n\n")
# Reset the buffer, then push the last batch once more.
fifo.clear()
print(fifo(x))
print(fifo.is_full, "\n\n")
CodePudding user response:
Using tf.TensorArray, you can try something like this:
import tensorflow as tf
import numpy as np
# Seed TensorFlow's global random generator so the tf.random.normal inputs below are repeatable.
tf.random.set_seed(111)
class FIFOLayer(tf.keras.layers.Layer):
    """Sliding-window FIFO buffer backed by a ``tf.TensorArray``.

    The buffer holds ``window_size`` rows and starts out filled with NaN.
    Every call drops the oldest rows, appends the new ones at the end and
    returns the whole buffer with a leading axis of size 1, together with
    an attention mask while the buffer is still only partially filled.
    """

    def __init__(self, window_size, **kwargs):
        """Store the buffer capacity and reset the received-row counter.

        Args:
            window_size: Number of rows the buffer holds.
            **kwargs: Forwarded to the base layer constructor.
        """
        super(FIFOLayer, self).__init__(**kwargs)
        self.window_size = window_size
        # Rows received since construction or the last clear(); call()
        # caps it at window_size.
        self.count = 0

    def _nan_queue(self, feature_dim):
        """Return a new ``window_size``-element TensorArray of NaN rows."""
        empty = tf.TensorArray(dtype=tf.float32, size=self.window_size)
        return empty.scatter(
            tf.range(self.window_size),
            tf.constant(np.nan) * tf.ones((self.window_size, feature_dim)),
        )

    def build(self, input_shape):
        """Create the NaN-filled buffer with ``input_shape[-1]`` features per row."""
        super(FIFOLayer, self).build(input_shape)
        self.queue_array = self._nan_queue(input_shape[-1])

    def call(self, inputs, training=None):
        """Push ``inputs`` into the buffer and return the updated window.

        Args:
            inputs: Tensor whose first axis indexes the new rows and whose
                last axis matches the buffer's feature size.
            training: Unused.

        Returns:
            A tuple ``(window, attention_mask)``. ``window`` is the stacked
            buffer with a new leading axis. ``attention_mask`` is a
            ``(window_size, window_size)`` float32 tensor that is 1 where
            both rows are free of NaN, or ``None`` once at least
            ``window_size`` rows have been received.

        Raises:
            ValueError: If ``inputs`` has more rows than ``window_size``.
        """
        # NOTE(review): the Python-level comparisons on tensors below
        # presumably rely on eager execution - confirm before wrapping
        # this layer in a tf.function.
        timesteps = tf.shape(inputs)[0]
        # The new rows must fit into the buffer in one go.
        if timesteps > self.window_size:
            raise ValueError(
                f"inputs has more rows than window_size={self.window_size}"
            )
        # 1. Keep the newest rows, append the inputs and write all slots back.
        kept_rows = self.queue_array.gather(tf.range(timesteps, self.window_size))
        self.queue_array = self.queue_array.scatter(
            tf.range(self.window_size),
            tf.concat([kept_rows, inputs], axis=0),
        )
        queue_tensor = self.queue_array.stack()
        # Accumulate (not overwrite) so the counter tracks the fill level.
        self.count += timesteps
        # 2. Feed-forward.
        if self.count < self.window_size:
            # A row is valid when none of its features is NaN.
            attention_mask = tf.cast(
                tf.math.reduce_all(
                    tf.math.logical_not(tf.math.is_nan(queue_tensor)), axis=-1
                ),
                dtype=tf.float32,
            )
            # Outer product of the validity vector with itself.
            attention_mask = tf.matmul(
                attention_mask[..., tf.newaxis],
                attention_mask[..., tf.newaxis],
                transpose_b=True,
            )
            return queue_tensor[tf.newaxis, ...], attention_mask
        # The counter may overshoot the capacity; clamp it so is_full
        # keeps reporting a full buffer.
        elif self.count > self.window_size:
            self.count = self.window_size
        return queue_tensor[tf.newaxis, ...], None

    @property
    def is_full(self):
        """Whether ``window_size`` rows have been received.

        The result is a scalar boolean tensor rather than a Python bool
        whenever ``count`` currently holds a tensor.
        """
        return self.count == self.window_size

    def clear(self):
        """Reset the counter and replace the buffer with NaN rows."""
        self.count = 0
        # Read the feature size from the current contents before discarding them.
        feature_dim = tf.shape(self.queue_array.stack())[-1]
        self.queue_array = self._nan_queue(feature_dim)
# Feed six random 2x12 batches through a 10-row window, printing each result.
fifo = FIFOLayer(window_size=10)
for _ in range(6):
    x = tf.random.normal((2, 12))
    y = fifo(x)
    print(y)
print(fifo.is_full, "\n\n")
# Reset the buffer, then push the last batch once more.
fifo.clear()
print(fifo(x))
print(fifo.is_full, "\n\n")
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ 0.7558127 , 1.5447265 , 1.6315602 , -0.19868968,
0.08828261, 0.01711658, -1.8133892 , 0.12930395,
0.47128937, 0.08567389, -1.7158676 , -0.5843805 ],
[-0.7664911 , -0.7145203 , -1.089696 , 0.14649415,
0.03585422, 0.9916008 , 0.9384322 , 0.34755042,
-0.09592161, 0.76490027, -1.2517685 , -1.5740465 ]]],
dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 1.]], dtype=float32)>)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ 0.7558127 , 1.5447265 , 1.6315602 , -0.19868968,
0.08828261, 0.01711658, -1.8133892 , 0.12930395,
0.47128937, 0.08567389, -1.7158676 , -0.5843805 ],
[-0.7664911 , -0.7145203 , -1.089696 , 0.14649415,
0.03585422, 0.9916008 , 0.9384322 , 0.34755042,
-0.09592161, 0.76490027, -1.2517685 , -1.5740465 ],
[-0.31995258, -0.43669155, -0.28932425, -0.06870204,
-0.01291991, 1.171546 , 0.75079876, -0.7693662 ,
0.05902815, 0.60606545, -1.1038904 , -0.99837613],
[-0.6687948 , 0.22192897, -0.02249479, -0.08962449,
1.2408841 , 0.119805 , -0.53699344, 1.020805 ,
0.9610218 , 0.6133564 , -0.4358486 , 2.733222 ]]],
dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 0., 1., 1., 1., 1.]], dtype=float32)>)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ 0.7558127 , 1.5447265 , 1.6315602 , -0.19868968,
0.08828261, 0.01711658, -1.8133892 , 0.12930395,
0.47128937, 0.08567389, -1.7158676 , -0.5843805 ],
[-0.7664911 , -0.7145203 , -1.089696 , 0.14649415,
0.03585422, 0.9916008 , 0.9384322 , 0.34755042,
-0.09592161, 0.76490027, -1.2517685 , -1.5740465 ],
[-0.31995258, -0.43669155, -0.28932425, -0.06870204,
-0.01291991, 1.171546 , 0.75079876, -0.7693662 ,
0.05902815, 0.60606545, -1.1038904 , -0.99837613],
[-0.6687948 , 0.22192897, -0.02249479, -0.08962449,
1.2408841 , 0.119805 , -0.53699344, 1.020805 ,
0.9610218 , 0.6133564 , -0.4358486 , 2.733222 ],
[-0.33772066, 0.80799913, -0.00896128, 1.606288 ,
1.1561627 , 0.17252289, 0.2451608 , 1.4633939 ,
-0.9294784 , 0.42795137, -0.3016553 , -1.1823792 ],
[ 0.30927372, 0.3482721 , 1.0262096 , -0.97228396,
-0.55333287, -0.7914886 , 1.0115404 , -0.5656188 ,
0.30958036, -0.8476673 , 2.4919312 , 0.9093976 ]]],
dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.]], dtype=float32)>)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ 0.7558127 , 1.5447265 , 1.6315602 , -0.19868968,
0.08828261, 0.01711658, -1.8133892 , 0.12930395,
0.47128937, 0.08567389, -1.7158676 , -0.5843805 ],
[-0.7664911 , -0.7145203 , -1.089696 , 0.14649415,
0.03585422, 0.9916008 , 0.9384322 , 0.34755042,
-0.09592161, 0.76490027, -1.2517685 , -1.5740465 ],
[-0.31995258, -0.43669155, -0.28932425, -0.06870204,
-0.01291991, 1.171546 , 0.75079876, -0.7693662 ,
0.05902815, 0.60606545, -1.1038904 , -0.99837613],
[-0.6687948 , 0.22192897, -0.02249479, -0.08962449,
1.2408841 , 0.119805 , -0.53699344, 1.020805 ,
0.9610218 , 0.6133564 , -0.4358486 , 2.733222 ],
[-0.33772066, 0.80799913, -0.00896128, 1.606288 ,
1.1561627 , 0.17252289, 0.2451608 , 1.4633939 ,
-0.9294784 , 0.42795137, -0.3016553 , -1.1823792 ],
[ 0.30927372, 0.3482721 , 1.0262096 , -0.97228396,
-0.55333287, -0.7914886 , 1.0115404 , -0.5656188 ,
0.30958036, -0.8476673 , 2.4919312 , 0.9093976 ],
[-0.44241378, -0.6971805 , -0.37439492, 1.0154608 ,
-0.34494257, 0.1988212 , -0.9541314 , -0.44339198,
0.162457 , -0.31033182, -0.34568167, 1.0341203 ],
[-0.89020306, -0.8646532 , 0.13348487, -0.6604107 ,
0.07642484, 1.3407826 , 0.79119945, -0.7598532 ,
0.85146165, -0.2791065 , -0.4600736 , 0.809218 ]]],
dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[ 7.5581270e-01, 1.5447265e+00, 1.6315602e+00, -1.9868968e-01,
8.8282607e-02, 1.7116580e-02, -1.8133892e+00, 1.2930395e-01,
4.7128937e-01, 8.5673891e-02, -1.7158676e+00, -5.8438051e-01],
[-7.6649112e-01, -7.1452028e-01, -1.0896960e+00, 1.4649415e-01,
3.5854220e-02, 9.9160081e-01, 9.3843222e-01, 3.4755042e-01,
-9.5921606e-02, 7.6490027e-01, -1.2517685e+00, -1.5740465e+00],
[-3.1995258e-01, -4.3669155e-01, -2.8932425e-01, -6.8702042e-02,
-1.2919909e-02, 1.1715460e+00, 7.5079876e-01, -7.6936620e-01,
5.9028149e-02, 6.0606545e-01, -1.1038904e+00, -9.9837613e-01],
[-6.6879481e-01, 2.2192897e-01, -2.2494787e-02, -8.9624494e-02,
1.2408841e+00, 1.1980500e-01, -5.3699344e-01, 1.0208050e+00,
9.6102178e-01, 6.1335641e-01, -4.3584859e-01, 2.7332220e+00],
[-3.3772066e-01, 8.0799913e-01, -8.9612845e-03, 1.6062880e+00,
1.1561627e+00, 1.7252289e-01, 2.4516080e-01, 1.4633939e+00,
-9.2947841e-01, 4.2795137e-01, -3.0165529e-01, -1.1823792e+00],
[ 3.0927372e-01, 3.4827209e-01, 1.0262096e+00, -9.7228396e-01,
-5.5333287e-01, -7.9148859e-01, 1.0115404e+00, -5.6561881e-01,
3.0958036e-01, -8.4766728e-01, 2.4919312e+00, 9.0939760e-01],
[-4.4241378e-01, -6.9718051e-01, -3.7439492e-01, 1.0154608e+00,
-3.4494257e-01, 1.9882120e-01, -9.5413142e-01, -4.4339198e-01,
1.6245700e-01, -3.1033182e-01, -3.4568167e-01, 1.0341203e+00],
[-8.9020306e-01, -8.6465323e-01, 1.3348487e-01, -6.6041070e-01,
7.6424837e-02, 1.3407826e+00, 7.9119945e-01, -7.5985318e-01,
8.5146165e-01, -2.7910650e-01, -4.6007359e-01, 8.0921799e-01],
[-6.7833281e-01, 4.7877081e-02, -2.0416839e+00, -1.5634586e+00,
-5.1782840e-01, 5.2898288e-01, -1.4573561e+00, 4.6455118e-01,
-3.2871577e-01, -1.5697428e+00, 1.4454672e-01, 8.2387424e-01],
[ 2.5552011e-03, 1.2834518e+00, 4.1382611e-01, 1.6535892e+00,
7.8654990e-02, -1.2952465e-01, 3.6811054e-01, 1.1675907e+00,
9.6434945e-01, -4.2399967e-01, -1.3700709e-01, -5.2056974e-01]]],
dtype=float32)>, None)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[-3.1995258e-01, -4.3669155e-01, -2.8932425e-01, -6.8702042e-02,
-1.2919909e-02, 1.1715460e+00, 7.5079876e-01, -7.6936620e-01,
5.9028149e-02, 6.0606545e-01, -1.1038904e+00, -9.9837613e-01],
[-6.6879481e-01, 2.2192897e-01, -2.2494787e-02, -8.9624494e-02,
1.2408841e+00, 1.1980500e-01, -5.3699344e-01, 1.0208050e+00,
9.6102178e-01, 6.1335641e-01, -4.3584859e-01, 2.7332220e+00],
[-3.3772066e-01, 8.0799913e-01, -8.9612845e-03, 1.6062880e+00,
1.1561627e+00, 1.7252289e-01, 2.4516080e-01, 1.4633939e+00,
-9.2947841e-01, 4.2795137e-01, -3.0165529e-01, -1.1823792e+00],
[ 3.0927372e-01, 3.4827209e-01, 1.0262096e+00, -9.7228396e-01,
-5.5333287e-01, -7.9148859e-01, 1.0115404e+00, -5.6561881e-01,
3.0958036e-01, -8.4766728e-01, 2.4919312e+00, 9.0939760e-01],
[-4.4241378e-01, -6.9718051e-01, -3.7439492e-01, 1.0154608e+00,
-3.4494257e-01, 1.9882120e-01, -9.5413142e-01, -4.4339198e-01,
1.6245700e-01, -3.1033182e-01, -3.4568167e-01, 1.0341203e+00],
[-8.9020306e-01, -8.6465323e-01, 1.3348487e-01, -6.6041070e-01,
7.6424837e-02, 1.3407826e+00, 7.9119945e-01, -7.5985318e-01,
8.5146165e-01, -2.7910650e-01, -4.6007359e-01, 8.0921799e-01],
[-6.7833281e-01, 4.7877081e-02, -2.0416839e+00, -1.5634586e+00,
-5.1782840e-01, 5.2898288e-01, -1.4573561e+00, 4.6455118e-01,
-3.2871577e-01, -1.5697428e+00, 1.4454672e-01, 8.2387424e-01],
[ 2.5552011e-03, 1.2834518e+00, 4.1382611e-01, 1.6535892e+00,
7.8654990e-02, -1.2952465e-01, 3.6811054e-01, 1.1675907e+00,
9.6434945e-01, -4.2399967e-01, -1.3700709e-01, -5.2056974e-01],
[ 1.3070145e+00, -6.7240512e-01, 1.9308577e+00, 1.7688200e-03,
3.0533668e-01, 6.5813893e-01, 5.2471739e-01, 2.1659613e+00,
-8.7725663e-01, 3.5695407e-01, -1.2751107e+00, -7.7276069e-01],
[-4.3180370e-01, -1.1814500e+00, 2.4167557e-01, 5.7490116e-01,
5.6998456e-01, -7.4528801e-01, -9.1826969e-01, -7.3694932e-01,
-1.2400552e+00, 1.6947891e+00, -2.6127639e-01, 7.8419834e-01]]],
dtype=float32)>, None)
True
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ 1.3070145e+00, -6.7240512e-01, 1.9308577e+00, 1.7688200e-03,
3.0533668e-01, 6.5813893e-01, 5.2471739e-01, 2.1659613e+00,
-8.7725663e-01, 3.5695407e-01, -1.2751107e+00, -7.7276069e-01],
[-4.3180370e-01, -1.1814500e+00, 2.4167557e-01, 5.7490116e-01,
5.6998456e-01, -7.4528801e-01, -9.1826969e-01, -7.3694932e-01,
-1.2400552e+00, 1.6947891e+00, -2.6127639e-01, 7.8419834e-01]]],
dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 1.]], dtype=float32)>)
tf.Tensor(False, shape=(), dtype=bool)
On a side note, using tf.queue.FIFOQueue is really slow.