Update random value of tensor using random indices

Time:11-04

I have a list of seeds generated elsewhere.

I need to use a seed to generate a list of random indices, then create a new tensor that is 0 everywhere except at those indices.

Here is my code so far:

import tensorflow as tf
size_for_layer = [800, 32, 51200, 64, 1605632, 512, 31744, 62]  # TODO: make this automatic
size_for_layer_submodel = [150, 6, 1950, 13, 261170, 410, 25420, 62]  # TODO: make this automatic
shape_for_layer = [[5, 5, 1, 32], [32], [5, 5, 32, 64], [64], [3136, 512], [512], [512, 62],
                   [62]]  # TODO: make this automatic
if __name__ == '__main__':
    seed = 1254  # Only for testing now
    index = 1
    # Tensor with X elements all 0s
    tensor_testing = tf.zeros(size_for_layer[index], tf.float32)
    # Tensor with Y random generated indices
    indices = tf.random.uniform(shape=[size_for_layer_submodel[index], ], minval=0,
                                maxval=size_for_layer[index], dtype=tf.dtypes.int64, seed=seed, name=None)


    values = tf.fill([tf.shape(indices)[0], ], 15.4)
    print(tensor_testing)
    print(indices)
    print(values)
    tensor_testing = tf.tensor_scatter_nd_update(tensor_testing, indices, values)

It fails with this error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Inner dimensions of output shape must match inner dimensions of updates shape. Output: [32] updates: [6] [Op:TensorScatterUpdate]

Answer:

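You were really close. tf.tensor_scatter_nd_update treats the last dimension of indices as the coordinate into the output tensor, so your 1-D indices of shape [6] is read as one 6-element coordinate rather than six separate indices, and the shapes don't line up. Adding an axis with tf.expand_dims(indices, 1) gives shape [6, 1], i.e. six 1-element coordinates into the 32-element tensor, one per update value. Something like this should do the trick: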

import tensorflow as tf
size_for_layer = [800, 32, 51200, 64, 1605632, 512, 31744, 62]  # TODO: make this automatic
size_for_layer_submodel = [150, 6, 1950, 13, 261170, 410, 25420, 62]  # TODO: make this automatic
shape_for_layer = [[5, 5, 1, 32], [32], [5, 5, 32, 64], [64], [3136, 512], [512], [512, 62],
                   [62]]

seed = 1254  # Only for testing now
index = 1
# Tensor with X elements all 0s
tensor_testing = tf.zeros(size_for_layer[index], tf.float32)
# Tensor with Y random generated indices
indices = tf.random.uniform(shape=[size_for_layer_submodel[index], ], minval=0,
                            maxval=size_for_layer[index], dtype=tf.dtypes.int64, seed=seed, name=None)

values = tf.fill([tf.shape(indices)[0], ], 15.4)
print(tensor_testing)
print(indices)
print(values)

# expand_dims reshapes indices from [6] to [6, 1]: six 1-D coordinates
tensor_testing = tf.tensor_scatter_nd_update(tensor_testing, tf.expand_dims(indices, 1), values)

print(tensor_testing)

Output:
tf.Tensor(
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.], shape=(32,), dtype=float32)
tf.Tensor([ 9  4 17 16 28  7], shape=(6,), dtype=int64)
tf.Tensor([15.4 15.4 15.4 15.4 15.4 15.4], shape=(6,), dtype=float32)
tf.Tensor(
[ 0.   0.   0.   0.  15.4  0.   0.  15.4  0.  15.4  0.   0.   0.   0.
  0.   0.  15.4 15.4  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
 15.4  0.   0.   0. ], shape=(32,), dtype=float32)
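One caveat: tf.random.uniform samples with replacement, so two of the drawn indices can coincide. A repeated index is simply written twice, so you can end up with fewer than size_for_layer_submodel[index] non-zero entries. If you need exactly k distinct positions, a minimal sketch (assuming a stateless seed pair built as [seed, 0] is acceptable for your seeds; the helper name seeded_sparse_tensor is hypothetical) is to sort random keys and keep the first k positions, which is sampling without replacement:

```python
import tensorflow as tf

def seeded_sparse_tensor(seed, size, k, value=15.4):
    """Hypothetical helper. Returns (tensor, indices): a float32 tensor of
    length `size` that is 0 everywhere except at `k` distinct, seed-determined
    positions holding `value`, plus those positions."""
    # Stateless ops need a shape-[2] seed; pairing the integer seed with 0
    # is an assumption, not something taken from the question.
    keys = tf.random.stateless_uniform([size], seed=[seed, 0])
    # argsort of random keys is a random permutation of 0..size-1, so the
    # first k entries are k distinct indices.
    indices = tf.argsort(keys)[:k]
    values = tf.fill([k], value)
    # Same [k, 1] index shape as in the answer above.
    result = tf.tensor_scatter_nd_update(
        tf.zeros([size], tf.float32), tf.expand_dims(indices, 1), values
    )
    return result, indices

result, idx = seeded_sparse_tensor(1254, 32, 6)
print(idx)
print(result)
```

Because the ops are stateless, the same seed yields the same indices on repeated calls, which fits seeds that arrive from somewhere else.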

Update: if you want to take the values from an existing tensor instead of filling them with a constant via tf.fill, gather them from that tensor at the same indices:

some_tensor = tf.random.uniform((32,), maxval=20)
# Use the value of some_tensor at each sampled position instead of a constant
tensor_testing = tf.tensor_scatter_nd_update(tensor_testing, tf.expand_dims(indices, 1), tf.gather(some_tensor, indices))

Output:
tf.Tensor(
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.], shape=(32,), dtype=float32)
tf.Tensor([17 10 21  0 23 27], shape=(6,), dtype=int64)
tf.Tensor(
[19.655283    0.73073626  5.853808   10.993803    2.2646523   1.0580301
  3.7602305  13.081946   10.444834   13.0695915  19.739437   13.987379
 14.613118   14.325147    3.355515   15.57209    13.402302   17.103617
 12.819632    7.440541   16.09658     1.3479114   1.6912937   0.5928588
 13.784771   14.848431   18.457924   10.463061   13.597097    3.3686733
  8.239708   16.517185  ], shape=(32,), dtype=float32)
tf.Tensor(
[19.655283   0.         0.         0.         0.         0.
  0.         0.         0.         0.        19.739437   0.
  0.         0.         0.         0.         0.        17.103617
  0.         0.         0.         1.3479114  0.         0.5928588
  0.         0.         0.        10.463061   0.         0.
  0.         0.       ], shape=(32,), dtype=float32)
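The example above only covers index = 1, which is 1-D, while most entries of shape_for_layer are 2-D or 4-D. Every entry of size_for_layer is the product of the matching shape (5*5*1*32 = 800, 3136*512 = 1605632, and so on), so those sizes can be derived instead of hard-coded. A sketch, assuming you want the result laid out in the layer's own shape (layer_tensor_from_seed is a hypothetical helper), is to scatter into a flat tensor exactly as above and then tf.reshape it:

```python
import math
import tensorflow as tf

shape_for_layer = [[5, 5, 1, 32], [32], [5, 5, 32, 64], [64], [3136, 512], [512], [512, 62], [62]]
# Derived rather than hard-coded: each flat size is the product of its layer shape.
size_for_layer = [math.prod(s) for s in shape_for_layer]

def layer_tensor_from_seed(seed, index, k, value=15.4):
    """Hypothetical helper: write `value` at k seed-chosen flat positions of
    layer `index`, then reshape the result to that layer's shape."""
    size = size_for_layer[index]
    # As in the original code, uniform sampling may repeat an index.
    indices = tf.random.uniform([k], minval=0, maxval=size, dtype=tf.int64, seed=seed)
    flat = tf.tensor_scatter_nd_update(
        tf.zeros([size], tf.float32), tf.expand_dims(indices, 1), tf.fill([k], value)
    )
    return tf.reshape(flat, shape_for_layer[index])

t = layer_tensor_from_seed(1254, 0, 150)
print(t.shape)  # (5, 5, 1, 32)
```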