Home > Back-end >  tf.while_loop() in call() of custom layer in TensorFlow 2
tf.while_loop() in call() of custom layer in TensorFlow 2

Time:03-16

I want to write a custom layer that applies a dense layer and then applies specified functions to the outputs of that computation. I want to specify, in a list, which function is applied to each individual output, so that I can easily change them.

I'm trying to apply the functions inside a tf.while_loop, but I don't know how to access and write to the individual elements of dense_output_nodes.

Assigning dense_output_nodes[i] = ... doesn't work, because it fails with

TypeError: 'Tensor' object does not support item assignment

So I tried calling tf.unstack first, as in the code below, but now, when creating the layer with hidden_1 = ArithmeticLayer(unit_types=['id', 'sin', 'cos'])(inputs), I get the error

TypeError: list indices must be integers or slices, not Tensor

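apparently because, inside tf.while_loop, the loop variable i is a symbolic tf.Tensor rather than a Python int (tf.constant already returns a tf.Tensor), and a Python list cannot be indexed by a symbolic tensor.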

By now, I'm really struggling to see ways I can fix this. Is there some way I can get this to work? Or should I build the whole ArithmeticLayer as a combination of a Dense layer and a Lambda layer applying the custom functions?

class ArithmeticLayer(layers.Layer):
    """Dense layer whose output units are each passed through a chosen function.

    The layer computes ``inputs @ w + b`` with one output unit per entry of
    ``unit_types``. It then transforms unit ``k`` according to
    ``unit_types[k]``:

    - ``'sin'`` applies ``tf.math.sin``;
    - ``'cos'`` applies ``tf.math.cos``;
    - any other name (e.g. ``'id'``) leaves the unit unchanged.

    Args:
        name: Optional layer name, forwarded to the base ``Layer``.
        regularizer: Optional regularizer applied to both ``w`` and ``b``.
        unit_types: Sequence of unit-type names, one per output unit. Its
            length sets the number of units ``u``.
    """

    def __init__(self, name=None, regularizer=None, unit_types=('id', 'sin', 'cos')):
        # Initialise the base Layer before assigning any attributes.
        super().__init__(name=name)
        self.regularizer = regularizer
        # u = number of units (one per entry of unit_types).
        self.u = len(unit_types)
        # Copy, so later mutation of the caller's sequence cannot affect the
        # layer. The default is a tuple to avoid a shared mutable default.
        self.u_types = list(unit_types)

    def build(self, input_shape):
        """Create the kernel ``w`` (input_dim, u) and the bias ``b`` (u,)."""
        self.w = self.add_weight(shape=(input_shape[-1], self.u),
                                 initializer='random_normal',
                                 regularizer=self.regularizer,
                                 trainable=True)
        self.b = self.add_weight(shape=(self.u,),
                                 initializer='random_normal',
                                 regularizer=self.regularizer,
                                 trainable=True)

    def call(self, inputs):
        """Apply the dense transform, then each unit's function column-wise."""
        dense_output_nodes = tf.matmul(inputs, self.w) + self.b

        # self.u_types is an ordinary Python list, known while the layer is
        # traced, so a plain Python loop can pick the function for each column.
        # A tf.while_loop cannot be used for this: its counter is a symbolic
        # tensor, and a Python list cannot be indexed by a symbolic tensor.
        dense_output_list = tf.unstack(dense_output_nodes, axis=1)
        final_output_list = []
        for unit_type, column in zip(self.u_types, dense_output_list):
            if unit_type == 'sin':
                column = tf.math.sin(column)
            elif unit_type == 'cos':
                column = tf.math.cos(column)
            final_output_list.append(column)

        return tf.stack(final_output_list, axis=1)

Thanks for any suggestions!

CodePudding user response:

Using tf.tensor_scatter_nd_update should do the trick if you want to apply certain functions column-wise across samples in a batch. Here is an example working in eager execution and graph mode:

import tensorflow as tf

class ArithmeticLayer(tf.keras.layers.Layer):
    """Dense layer whose output columns are transformed by per-unit functions.

    Column ``k`` of ``inputs @ w + b`` is transformed according to
    ``unit_types[k]``:

    - ``'sin'`` replaces the column with its sine;
    - ``'cos'`` replaces the column with its cosine;
    - any other unit type (e.g. ``'id'``) leaves the column unchanged.

    A tensor does not support item assignment, so the per-column update is
    done with ``tf.tensor_scatter_nd_update`` inside a ``tf.while_loop``.

    Args:
        name: Optional layer name, forwarded to the base ``Layer``.
        regularizer: Optional regularizer applied to both ``w`` and ``b``.
        unit_types: Sequence of unit-type names, one per output unit.
    """

    def __init__(self, name=None, regularizer=None, unit_types=('id', 'sin', 'cos')):
        # Initialise the base Layer before assigning any attributes.
        super().__init__(name=name)
        self.regularizer = regularizer
        # Stored as a string tensor so the tensor-valued loop counter used in
        # call() can index it.
        self.u_types = tf.constant(unit_types)
        # self.u_shape[0] is the number of units.
        self.u_shape = tf.shape(self.u_types)

    def build(self, input_shape):
        """Create the kernel ``w`` (input_dim, units) and the bias ``b`` (units,)."""
        self.w = self.add_weight(shape=(input_shape[-1], self.u_shape[0]),
                                 initializer='random_normal',
                                 regularizer=self.regularizer,
                                 trainable=True)
        self.b = self.add_weight(shape=(self.u_shape[0],),
                                 initializer='random_normal',
                                 regularizer=self.regularizer,
                                 trainable=True)

    def call(self, inputs):
        """Apply the dense transform, then each unit's function column-wise."""
        dense_output_nodes = tf.matmul(inputs, self.w) + self.b
        batch_size = tf.shape(dense_output_nodes)[0]
        rows = tf.range(batch_size)

        def column_indices(i):
            # (batch_size, 2) tensor of [row, i] pairs addressing column i.
            return tf.stack([rows, tf.repeat([i], batch_size)], axis=1)

        def c(i, d):
            return tf.less(i, self.u_shape[0])

        def b(i, d):
            # The unit type must come from self.u_types; the bare name
            # `unit_types` is not defined in this scope.
            unit_type = self.u_types[i]
            d = tf.cond(
                tf.equal(unit_type, 'sin'),
                lambda: tf.tensor_scatter_nd_update(
                    d, column_indices(i), tf.math.sin(d[:, i])),
                lambda: d)
            d = tf.cond(
                tf.equal(unit_type, 'cos'),
                lambda: tf.tensor_scatter_nd_update(
                    d, column_indices(i), tf.math.cos(d[:, i])),
                lambda: d)
            return tf.add(i, 1), d

        i = tf.constant(0)
        _, dense_output_nodes = tf.while_loop(c, b, loop_vars=[i, dense_output_nodes])

        return dense_output_nodes
# Smoke test: wire the layer into a functional model and fit it for one epoch
# on a tiny random regression problem (4 samples, 3 features, 3 targets).
x = tf.random.normal(shape=(4, 3))

inputs = tf.keras.layers.Input(shape=(3,))
arithmetic = ArithmeticLayer()
outputs = arithmetic(inputs)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss='mse')

# Targets are drawn here, after the model is built, which matches the
# original order of random draws.
targets = tf.random.normal(shape=(4, 3))
model.fit(x=x, y=targets, batch_size=2)
2/2 [==============================] - 3s 11ms/step - loss: 1.4259
<keras.callbacks.History at 0x7fe50728c850>

CodePudding user response:

If you are open to using a different data structure (tf.TensorArray), try this:

import tensorflow as tf

# The unit types and the dense outputs are kept in TensorArrays so that the
# tensor-valued counter of tf.while_loop can read them; a Python list cannot
# be indexed by a symbolic tensor.
u_types = ["sin", "cos"]
u = len(u_types)
u_types_ta = tf.TensorArray(dtype=tf.string, size=u, clear_after_read=False)
u_types_ta = u_types_ta.unstack(u_types)

dense_output_list = [1., 2.]
dense_output_ta = tf.TensorArray(dtype=tf.float32, size=len(dense_output_list),
                                 clear_after_read=False)
dense_output_ta = dense_output_ta.unstack(dense_output_list)

# Receives exactly one transformed value per unit.
ta = tf.TensorArray(dtype=tf.float32, size=u)


def c(i, _):
    """Keep looping until every unit has been processed."""
    return tf.less(i, u)


def b(i, ta):
    """Transform element i according to its unit type and write it to ta once."""
    unit_type = u_types_ta.read(i)
    value = dense_output_ta.read(i)
    # A single nested cond: writing the same index twice would overwrite a
    # 'sin' result with the identity branch (and is rejected in graph mode).
    result = tf.cond(
        tf.equal(unit_type, 'sin'),
        lambda: tf.math.sin(value),
        lambda: tf.cond(tf.equal(unit_type, 'cos'),
                        lambda: tf.math.cos(value),
                        lambda: value),
    )
    # TensorArray.write returns the updated array, which must be threaded
    # through the loop rather than discarded.
    return tf.add(i, 1), ta.write(i, result)


# The counter is created here, right before the loop. Reusing `i` as a Python
# loop variable earlier would leave it at 1 and skip element 0.
i = tf.constant(0)
i, ta = tf.while_loop(c, b, [i, ta])
print(ta.stack())
  • Related