In short, how do I transform this tensor:
[[1. 1. 1. 1.]
[1. 1. 1. 0.]
[1. 1. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 1.]
[0. 1. 1. 1.]]
into this one:
[[1. 1. 1. 1.]
[1. 1. 1. 0.]
[1. 1. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 1. 0. 0.]
[1. 1. 1. 0.]]
For each sample (row), if its first value along dimension 1 is 0, I want to reverse that sample along dimension 1; otherwise, it should stay unchanged. I am actually working with a 3-dimensional tensor of shape (batch, horizon, feature), but I simplified the example tensors here.
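To pin down the rule exactly, here is a small plain-Python reference on nested lists. The helper name reverse_rows_with_leading_zero is hypothetical and not part of the question; it only covers the simplified 2-D case, but it is handy for checking a TensorFlow version against on small inputs.

```python
def reverse_rows_with_leading_zero(rows):
    """Return a new list where every row whose first element is 0 is
    reversed; all other rows are copied unchanged."""
    return [row[::-1] if row[0] == 0 else list(row) for row in rows]


rows = [[1, 1, 1, 1],
        [1, 1, 1, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 1, 1],
        [0, 1, 1, 1]]

for row in reverse_rows_with_leading_zero(rows):
    print(row)
```

Note that an all-zero row also starts with 0, so it is "reversed", which leaves it unchanged.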
Answer:
Here is one approach that might be useful: find the rows whose first element is 0, gather and reverse them, then write them back into the tensor with tf.tensor_scatter_nd_update:
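import tensorflow as tf

x = tf.constant([[1., 1., 1., 1.],
                 [1., 1., 1., 0.],
                 [1., 1., 0., 0.],
                 [1., 0., 0., 0.],
                 [0., 0., 0., 0.],
                 [0., 0., 0., 1.],
                 [0., 0., 1., 1.],
                 [0., 1., 1., 1.]])

# Row indices (shape [k, 1], int64) of every row whose first element is 0.
indices = tf.where(x[:, 0] == 0.0)
n_cols = tf.shape(x)[-1]
# Build one (row, col) index pair for every element of those rows.
new_indices = tf.stack([tf.repeat(indices[:, 0], n_cols),
                        tf.tile(tf.range(n_cols, dtype=tf.int64), [tf.shape(indices)[0]])],
                       axis=-1)
# Gather the selected rows ([k, 1, n] -> [k, n]) and reverse each one.
values = tf.reverse(tf.squeeze(tf.gather(x, indices), axis=1), axis=[-1])
# Write the reversed values back into a copy of x.
result = tf.tensor_scatter_nd_update(x, new_indices, tf.reshape(values, [-1]))
print(result)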
tf.Tensor(
[[1. 1. 1. 1.]
[1. 1. 1. 0.]
[1. 1. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 1. 0. 0.]
[1. 1. 1. 0.]], shape=(8, 4), dtype=float32)
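A possibly simpler alternative, sketched here as an assumption rather than as part of the answer above: build a boolean mask from the first step of each sample, reverse the whole tensor along axis 1 with tf.reverse, and let tf.where pick, per sample, either the reversed or the original values. This relies on tf.where broadcasting its condition, which it does in TensorFlow 2.x. The function name reverse_if_leading_zero and its feature parameter are hypothetical. For the 3-D (batch, horizon, feature) case the question does not say which feature decides, so this sketch assumes a single feature index (0 by default) checked at horizon step 0.

```python
import tensorflow as tf


def reverse_if_leading_zero(x, feature=0):
    """Reverse each sample along axis 1 when its first step is 0.

    Handles 2-D (batch, horizon) and 3-D (batch, horizon, feature) input.
    For 3-D input, the hypothetical `feature` argument selects which
    feature at horizon step 0 is tested. Assumes a statically known rank.
    """
    rank = x.shape.rank
    first = x[:, 0]                        # (batch,) or (batch, feature)
    if rank == 3:
        first = first[:, feature]          # (batch,)
    mask = tf.equal(first, 0.0)            # (batch,) boolean
    # Append singleton axes so the mask broadcasts over the remaining dims.
    mask = tf.reshape(mask, [-1] + [1] * (rank - 1))
    return tf.where(mask, tf.reverse(x, axis=[1]), x)


x = tf.constant([[1., 1., 1., 1.],
                 [1., 1., 1., 0.],
                 [1., 1., 0., 0.],
                 [1., 0., 0., 0.],
                 [0., 0., 0., 0.],
                 [0., 0., 0., 1.],
                 [0., 0., 1., 1.],
                 [0., 1., 1., 1.]])

print(reverse_if_leading_zero(x))

# 3-D example: two identical features stacked on the last axis.
x3 = tf.stack([x, x], axis=-1)             # shape (8, 4, 2)
print(reverse_if_leading_zero(x3).shape)
```

This reverses every sample and then discards the unwanted ones, which costs some extra work compared to the scatter approach, but it avoids building index tensors and works the same way for 2-D and 3-D input.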