How to construct a Tensor with values at given indices inside graph mode?

Time:07-09

I have a 1D tensor of length 128 with logits in it. For a custom loss, I'm trying to replace the 3 highest values with 1.0 and the rest with 0.0. This is inside a @tf.function, so I can't convert it to NumPy and do the manipulation there.
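For reference, the same 0/1 mask can also be built without any scatter op by summing one-hot rows. This is a swapped-in technique, not the asker's or the answerer's method, and the helper name top_k_mask is hypothetical. It is a minimal sketch that assumes a 1-D float input; it uses only TensorFlow ops, so it can also be traced inside a @tf.function.

```python
import tensorflow as tf

def top_k_mask(logits, k=3):
    """Hypothetical helper: 1.0 at the k largest entries of a 1-D tensor, 0.0 elsewhere."""
    depth = tf.shape(logits)[0]                                # vector length (128 in the question)
    indices = tf.math.top_k(logits, k=k).indices               # shape [k], int32, all distinct
    one_hots = tf.one_hot(indices, depth, dtype=logits.dtype)  # shape [k, depth], one 1.0 per row
    # top_k never returns the same position twice, so the sum contains only 0.0 and 1.0
    return tf.reduce_sum(one_hots, axis=0)                     # shape [depth]

code = tf.random.normal((128,))
mask = top_k_mask(code)
print(mask)
```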

I've come up with:

top_3 = tf.math.top_k(code, k=3)
indices = top_3.indices    
updates = tf.ones_like(indices)
new_code = tf.scatter_nd(indices, updates, tf.constant([128]))

But it gives me the error:

ValueError: Dimensions [3,1) of input[shape=[?]] = [] must match dimensions [0,1) of updates[shape=[3]] = [3]: Shapes must be equal rank, but are 0 and 1 for '{{node ScatterNd}} = ScatterNd[T=DT_INT32, Tindices=DT_INT32](TopKV2:1, ones_like_1, Const_3)' with input shapes: [3], [3], [1].

which I don't understand, because indices should have length 3, and so does updates. What's the problem?
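The asker's own tf.scatter_nd call can be kept with two changes: give indices a trailing axis so it has shape [3, 1] (one single-coordinate index per update), and make the updates float so the result holds 1.0/0.0 rather than int32 values. A minimal sketch, with random logits standing in for the real ones:

```python
import tensorflow as tf

code = tf.random.normal((128,))                   # stand-in for the 128 logits
indices = tf.math.top_k(code, k=3).indices        # shape [3]: read by scatter_nd as ONE 3-coordinate index

indices_2d = indices[:, tf.newaxis]               # shape [3, 1]: three 1-coordinate indices
updates = tf.ones_like(indices, dtype=tf.float32) # float so the output matches the logits' dtype
# shape must have the same dtype as indices (int32 here)
new_code = tf.scatter_nd(indices_2d, updates, shape=tf.constant([128]))
print(new_code)
```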

Answer:

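The problem is the shape of indices. tf.scatter_nd (like tf.tensor_scatter_nd_update) treats the last dimension of indices as the number of coordinates per index, so a tensor of shape [3] is read as a single index with three coordinates, which cannot address a rank-1 output of shape [128]; hence the shape mismatch in the error. Each index needs its own row, i.e. shape [3, 1]. Also, tf.ones_like(indices) is int32, so the updates should be float to get 1.0/0.0 values. Try: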

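import tensorflow as tf

code = tf.random.normal((128,))  # stand-in for the 128 logits
top_3 = tf.math.top_k(code, k=3)
indices = top_3.indices  # shape [3], int32

updates = tf.ones_like(indices, dtype=tf.float32)  # 1.0 for each of the top 3 positions
new_code = tf.zeros_like(code)  # every other position stays 0.0
# indices[..., None] has shape [3, 1]: one single-coordinate index per update
new_code = tf.tensor_scatter_nd_update(new_code, indices[..., None], updates)
print(new_code)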
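Since the question runs this inside a @tf.function, here is a sketch of the same tensor_scatter_nd_update approach traced in graph mode. The function name top3_to_ones and the fixed [128] float32 input signature are assumptions for illustration.

```python
import tensorflow as tf

# Hypothetical wrapper: the answer's approach traced as a graph via @tf.function.
@tf.function(input_signature=[tf.TensorSpec(shape=[128], dtype=tf.float32)])
def top3_to_ones(code):
    indices = tf.math.top_k(code, k=3).indices           # shape [3], int32
    updates = tf.ones_like(indices, dtype=code.dtype)    # three 1.0 values
    return tf.tensor_scatter_nd_update(
        tf.zeros_like(code),                             # start from all 0.0
        indices[..., tf.newaxis],                        # shape [3, 1]: one index per row
        updates,
    )

code = tf.random.normal((128,))
new_code = top3_to_ones(code)
print(new_code)
```

Because the mask is built from top_k's integer indices, it behaves as a constant with respect to code for gradient purposes; a custom loss that needs gradients with respect to the logits would have to get them from another term.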