In TensorFlow, is there a way to generate a new tensor from a given tensor in the following manner?

Time:10-02

In the code below:

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as kb  # tf.keras backend, to match the tf.keras Inputs

ip1 = keras.Input((20,), name='ip1')
ip2 = keras.Input((20,), name='ip2')
ip3 = keras.Input((20,), name='ip3')
ip4 = keras.Input((20,), name='ip4')
ip5 = keras.Input((20,), name='ip5')
ip6 = keras.Input((20,), name='ip6')
ip7 = keras.Input((20,), name='ip7')
ip8 = keras.Input((20,), name='ip8')
ip9 = keras.Input((20,), name='ip9')

# Stack the nine (batch, 20) inputs into one (batch, 9, 20) tensor
ip_concat = kb.stack([ip1, ip2, ip3, ip4, ip5, ip6, ip7, ip8, ip9], axis=1)

# Dense acts on the last axis: (batch, 9, 20) -> (batch, 9, 128).
# input_shape is ignored when a layer is called on an existing tensor, so it is omitted.
rel_features = layers.Dense(128)(ip_concat)

# (batch, 9, 128) -> (batch, 9, 64)
f1_enc = layers.Dense(64, activation='tanh')(rel_features)

Here, f1_enc has a shape of 9x64 per sample (its full symbolic shape is (None, 9, 64), where None is the batch dimension). What I need is easiest to explain with an example; for simplicity, it uses a 3x3 initial tensor instead of a 9x64 one:

[[a, b, c],
 [d, e, f],
 [g, h, i]]

needs to be converted to a 3x3x6 tensor in which entry [i, j] is row i followed by row j:

[[[a, b, c, a, b, c],
  [a, b, c, d, e, f],
  [a, b, c, g, h, i]],

 [[d, e, f, a, b, c],
  [d, e, f, d, e, f],
  [d, e, f, g, h, i]],

 [[g, h, i, a, b, c],
  [g, h, i, d, e, f],
  [g, h, i, g, h, i]]]
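As a reference for checking any TensorFlow version against, the same definition can be written with NumPy broadcasting. This is a minimal sketch; the function name `pairwise_concat_np` is illustrative and not from the question.

```python
import numpy as np

def pairwise_concat_np(x):
    """Return y with y[i, j] = concatenate(x[i], x[j]) for a 2-D array x of shape (n, d)."""
    n, d = x.shape
    left = np.broadcast_to(x[:, None, :], (n, n, d))   # left[i, j] = x[i]
    right = np.broadcast_to(x[None, :, :], (n, n, d))  # right[i, j] = x[j]
    return np.concatenate([left, right], axis=-1)      # shape (n, n, 2 * d)

x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
y = pairwise_concat_np(x)
print(y.shape)  # (3, 3, 6)
print(y[1, 2])  # [4 5 6 7 8 9]
```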

I know that tensors are immutable in TensorFlow, so looping over the tensor and assigning the values into a new tensor won't work.

I also tried experimenting with eager execution, because converting the tensor to a NumPy array did not work. As I now understand it, that conversion fails because f1_enc is a symbolic Keras tensor inside a model, not a concrete tensor with values.

Gradients will need to flow back to ip_concat, but they do not need to flow through the new tensor; ip_concat will be connected to another Keras layer later.

Answer:

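This might not be the most efficient way, but it works on a concrete (eager) tensor. Because it calls .numpy(), it cannot be applied directly to a symbolic tensor inside a Keras model: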

import numpy as np
import tensorflow as tf

tensor = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(tensor)
# tf.Tensor(
# [[1 2 3]
#  [4 5 6]
#  [7 8 9]], shape=(3, 3), dtype=int32)

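x = tensor.numpy()  # .numpy() only works on eager (concrete) tensors
# Outer loop over rows j, inner loop over rows i: each entry is row j followed by row i
tensor = tf.constant([[np.concatenate((j, i)) for i in x] for j in x])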
print(tensor)
# tf.Tensor(
# [[[1 2 3 1 2 3]
#   [1 2 3 4 5 6]
#   [1 2 3 7 8 9]]
#  [[4 5 6 1 2 3]
#   [4 5 6 4 5 6]
#   [4 5 6 7 8 9]]
#  [[7 8 9 1 2 3]
#   [7 8 9 4 5 6]
#   [7 8 9 7 8 9]]], shape=(3, 3, 6), dtype=int32)
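Since the loop above relies on `.numpy()`, a version built only from TensorFlow ops is needed for the symbolic `f1_enc`. The sketch below assumes the input has shape (..., n, d), so it also handles a leading batch dimension, and wraps the op in a `Lambda` layer. The name `pairwise_concat` and the standalone `Input((9, 64))` standing in for `f1_enc` are illustrative assumptions. `tf.stop_gradient` is one way to express "no backpropagation through the new tensor"; drop it if gradients should flow after all.

```python
import tensorflow as tf
from tensorflow import keras

def pairwise_concat(x):
    """(..., n, d) -> (..., n, n, 2d) with out[..., i, j, :] = concat(x[..., i, :], x[..., j, :])."""
    left = x[..., :, tf.newaxis, :]   # (..., n, 1, d): row i
    right = x[..., tf.newaxis, :, :]  # (..., 1, n, d): row j
    # Adding zeros broadcasts both halves to (..., n, n, d) and keeps static shapes known.
    left_full = left + tf.zeros_like(right)
    right_full = right + tf.zeros_like(left)
    return tf.concat([left_full, right_full], axis=-1)

# Works on the 3x3 example without calling .numpy()
tensor = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(pairwise_concat(tensor).shape)  # (3, 3, 6)

# Inside a functional model; the batch dimension is handled by the leading "..."
f1_enc = keras.Input((9, 64), name="f1_enc")  # hypothetical stand-in for the question's f1_enc
pairs = keras.layers.Lambda(lambda t: pairwise_concat(tf.stop_gradient(t)))(f1_enc)
model = keras.Model(f1_enc, pairs)
print(model(tf.random.normal((2, 9, 64))).shape)  # (2, 9, 9, 128)
```

Broadcasting against zeros is used instead of `tf.tile` so that the number of rows does not have to be hard-coded.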

Answer:

Use a combination of tf.unstack, tf.tile, tf.reshape and tf.concat. Start from the initial tensor A = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]):

<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]], dtype=int32)>

First, unstack the rows:

A1, A2, A3 = tf.unstack(A, axis=0)

Output:

(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>,
 <tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>,
 <tf.Tensor: shape=(3,), dtype=int32, numpy=array([7, 8, 9], dtype=int32)>)

Then tile each row three times and reshape the result to 3x3, so every row of the new matrix is the original row:

L1, L2, L3 = [tf.reshape(tf.tile(X, [3]), (3,3)) for X in [A1, A2, A3]]

Output:

(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
 array([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]], dtype=int32)>,
 <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
 array([[4, 5, 6],
        [4, 5, 6],
        [4, 5, 6]], dtype=int32)>,
 <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
 array([[7, 8, 9],
        [7, 8, 9],
        [7, 8, 9]], dtype=int32)>)

Then concatenate each of these matrices with the initial tensor along axis 1:

R1, R2, R3 = [tf.concat([X, A], axis=1) for X in [L1, L2, L3]]

Output:

(<tf.Tensor: shape=(3, 6), dtype=int32, numpy=
 array([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 4, 5, 6],
        [1, 2, 3, 7, 8, 9]], dtype=int32)>,
 <tf.Tensor: shape=(3, 6), dtype=int32, numpy=
 array([[4, 5, 6, 1, 2, 3],
        [4, 5, 6, 4, 5, 6],
        [4, 5, 6, 7, 8, 9]], dtype=int32)>,
 <tf.Tensor: shape=(3, 6), dtype=int32, numpy=
 array([[7, 8, 9, 1, 2, 3],
        [7, 8, 9, 4, 5, 6],
        [7, 8, 9, 7, 8, 9]], dtype=int32)>)
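These steps end with three separate (3, 6) tensors; the requested (3, 3, 6) result still needs a `tf.stack` along a new first axis. A self-contained sketch of the whole sequence, with the hard-coded `3` replaced by the row count. It assumes a 2-D tensor with a statically known number of rows (`tf.unstack` needs that) and no batch dimension:

```python
import tensorflow as tf

A = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
n = A.shape[0]  # static number of rows

rows = tf.unstack(A, axis=0)                                  # n tensors of shape (d,)
tiled = [tf.reshape(tf.tile(r, [n]), (n, -1)) for r in rows]  # each (n, d): one row repeated n times
blocks = [tf.concat([t, A], axis=1) for t in tiled]           # each (n, 2d)
R = tf.stack(blocks, axis=0)                                  # (n, n, 2d)
print(R.shape)  # (3, 3, 6)
```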