How to reverse a specific dimension/sample of a tensor only when a per-dimension/sample condition is

Time:08-04

In short, how do I transform this tensor:

[[1. 1. 1. 1.]
 [1. 1. 1. 0.]
 [1. 1. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 1. 1.]
 [0. 1. 1. 1.]]

into this one:

[[1. 1. 1. 1.]
 [1. 1. 1. 0.]
 [1. 1. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 1. 0. 0.]
 [1. 1. 1. 0.]]

If a sample's first value along dimension 1 (index 0) is 0, I want to reverse that entire sample along dimension 1; otherwise, the sample should stay unchanged. I am actually working with a 3-D tensor of shape (batch, horizon, feature), but I simplified the example tensors here.
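To pin down the intended rule, here is a small framework-independent reference sketch. NumPy is an assumption here (the question is about TensorFlow), and the function name is hypothetical. It builds a per-sample boolean mask from column 0 and uses np.where to choose, row by row, between the original row and its reversed copy; the expected array from the question can then be used to check any TensorFlow version.

```python
import numpy as np

def reverse_rows_where_first_is_zero(x: np.ndarray) -> np.ndarray:
    """Reverse each row of a 2-D array whose first element is 0; leave other rows unchanged."""
    mask = x[:, 0] == 0  # shape (batch,), True for rows that should be reversed
    # Broadcast the mask over the last axis and pick reversed or original values.
    return np.where(mask[:, None], x[:, ::-1], x)

x = np.array([[1., 1., 1., 1.],
              [1., 1., 1., 0.],
              [1., 1., 0., 0.],
              [1., 0., 0., 0.],
              [0., 0., 0., 0.],
              [0., 0., 0., 1.],
              [0., 0., 1., 1.],
              [0., 1., 1., 1.]])

expected = np.array([[1., 1., 1., 1.],
                     [1., 1., 1., 0.],
                     [1., 1., 0., 0.],
                     [1., 0., 0., 0.],
                     [0., 0., 0., 0.],
                     [1., 0., 0., 0.],
                     [1., 1., 0., 0.],
                     [1., 1., 1., 0.]])

result = reverse_rows_where_first_is_zero(x)
print(np.array_equal(result, expected))  # True
```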

Answer:

One way to do this is to find the rows whose first value is 0, gather and reverse those rows, and write them back into the original tensor with tf.tensor_scatter_nd_update:

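import tensorflow as tf

x = tf.constant([[1., 1., 1., 1.],
                 [1., 1., 1., 0.],
                 [1., 1., 0., 0.],
                 [1., 0., 0., 0.],
                 [0., 0., 0., 0.],
                 [0., 0., 0., 1.],
                 [0., 0., 1., 1.],
                 [0., 1., 1., 1.]])

# Row indices of the samples whose first value is 0, shape (num_selected, 1).
indices = tf.where(x[:, 0] == 0.0)
num_cols = tf.shape(x)[-1]
# One (row, col) index pair for every element of every selected row.
new_indices = tf.stack([tf.repeat(indices[:, 0], num_cols),
                        tf.tile(tf.range(num_cols, dtype=tf.int64), [tf.shape(indices)[0]])],
                       axis=-1)
# Gather the selected rows and reverse each one along the last axis.
values = tf.reverse(tf.squeeze(tf.gather(x, indices), axis=1), axis=[-1])
# Write the reversed rows back into a copy of x.
result = tf.tensor_scatter_nd_update(x, new_indices, tf.reshape(values, [-1]))
print(result)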
tf.Tensor(
[[1. 1. 1. 1.]
 [1. 1. 1. 0.]
 [1. 1. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 1. 0. 0.]
 [1. 1. 1. 0.]], shape=(8, 4), dtype=float32)
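A shorter alternative (a different technique from the answer above, not the answerer's method) is to reverse the whole tensor once and let tf.where pick, per sample, between the reversed and the original values. The sketch below also covers the 3-D (batch, horizon, feature) case from the question, under the assumption that the condition is checked on a single feature at horizon step 0; the function name and its `feature` argument are hypothetical. It relies on TensorFlow 2's broadcasting tf.where and on inputs whose rank is known statically.

```python
import tensorflow as tf

def reverse_where_first_step_is_zero(x, feature=0):
    """Reverse x along axis 1 for every sample whose value at step 0 is 0.

    Handles 2-D (batch, horizon) and 3-D (batch, horizon, features) inputs.
    For 3-D inputs, the hypothetical `feature` argument selects which feature
    is tested at step 0 (the question does not specify this).
    """
    rank = x.shape.rank  # assumes a statically known rank
    first = x[:, 0] if rank == 2 else x[:, 0, feature]
    cond = tf.equal(first, 0.0)  # shape (batch,)
    # Reshape to (batch, 1, ...) so the condition broadcasts over the other axes.
    cond = tf.reshape(cond, [-1] + [1] * (rank - 1))
    return tf.where(cond, tf.reverse(x, axis=[1]), x)

x = tf.constant([[1., 1., 1., 1.],
                 [1., 1., 1., 0.],
                 [1., 1., 0., 0.],
                 [1., 0., 0., 0.],
                 [0., 0., 0., 0.],
                 [0., 0., 0., 1.],
                 [0., 0., 1., 1.],
                 [0., 1., 1., 1.]])
print(reverse_where_first_step_is_zero(x))

# 3-D case: a second feature channel (2 * x) is added purely for illustration.
x3 = tf.stack([x, 2.0 * x], axis=-1)  # shape (8, 4, 2)
y3 = reverse_where_first_step_is_zero(x3, feature=0)
print(y3[..., 0])  # same values as the 2-D result
```

This version computes the reversed copy of every sample and then selects, which avoids the index bookkeeping of the scatter approach at the cost of reversing rows that are later discarded.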