As a minimal example, say I have a tensor of the form:
[[ 1. 0. 3. ]
[ 7. 5. 6. ]
[ 0. 0. 0. ]
[ 0. 11. 1. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[13. 14. 16.5]]
Is there a way (natively in TensorFlow) to impute the fully zeroed rows so that each one takes the values of the last row that is not fully zeroed? I.e.:
[[ 1. 0. 3. ]
[ 7. 5. 6. ]
[ 7. 5. 6. ]
[ 0. 11. 1. ]
[ 0. 11. 1. ]
[ 0. 11. 1. ]
[13. 14. 16.5]]
I thought about using tf.tensor_scatter_nd_update, but with no success.
CodePudding user response:
This code can also run on a GPU.
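import tensorflow as tf

data = tf.constant([[ 1.,  0.,  3. ],
                    [ 7.,  5.,  6. ],
                    [ 0.,  0.,  0. ],
                    [ 0., 11.,  1. ],
                    [ 0.,  0.,  0. ],
                    [ 0.,  0.,  0. ],
                    [13., 14., 16.5]])

# Working copy that the loop body rebinds as rows get filled in.
_data = data

def find_zeros_and_update(i):
    # Relies on eager execution (the TF 2.x default): Python `if` and `global` are used here.
    global _data
    # A row is treated as zeroed when its sum is 0; note that a row whose
    # entries cancel out, e.g. [1., -1., 0.], would also match.
    if i > 0 and tf.reduce_sum(_data[i]) == 0:
        # Overwrite row i with row i-1, which has already been imputed if needed.
        indices = tf.reshape(i, (1, 1))
        update = _data[i - 1][tf.newaxis, :]
        _data = tf.tensor_scatter_nd_update(_data, indices, update)
    return (tf.add(i, 1),)

# The body must be defined before the loop is started.
i = tf.constant(0)
c = lambda i: tf.less(i, len(data))
tf.while_loop(c, find_zeros_and_update, [i])
_data  # the imputed tensor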
Output:
<tf.Tensor: shape=(7, 3), dtype=float32, numpy=
array([[ 1. , 0. , 3. ],
[ 7. , 5. , 6. ],
[ 7. , 5. , 6. ],
[ 0. , 11. , 1. ],
[ 0. , 11. , 1. ],
[ 0. , 11. , 1. ],
[13. , 14. , 16.5]], dtype=float32)>
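The loop above amounts to a forward fill over rows. As a rough cross-check for small inputs, the same logic can be sketched in plain Python on nested lists; the helper name forward_fill_rows is hypothetical and not part of TensorFlow, and this sketch tests every entry for zero rather than the row sum.

```python
def forward_fill_rows(rows):
    """Return a copy of rows where each all-zero row repeats the previous filled row."""
    filled = []
    for row in rows:
        # A leading all-zero row has nothing before it, so it is kept as-is.
        if filled and all(v == 0 for v in row):
            filled.append(list(filled[-1]))
        else:
            filled.append(list(row))
    return filled
```

Comparing forward_fill_rows(data.numpy().tolist()) with _data.numpy().tolist() is one way to sanity-check the TensorFlow result.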
CodePudding user response:
We can use tf.gather(a, indices) to get the above output, where a is the input tensor.
The indices need to be [0, 1, 1, 3, 3, 3, 6], which can be obtained with the following code:
mask = tf.cast(tf.cast(tf.reduce_sum(a, axis=1), dtype=tf.bool), tf.float32)
# [1., 1., 0., 1., 0., 0., 1.] -> 1 where the row sum is non-zero
mask_range = mask * tf.range(a.shape[0], dtype=tf.float32)
# [0., 1., 0., 3., 0., 0., 6.] -> row index kept only for non-zero rows
# (the lambda arguments are named acc/x so they do not shadow the tensor a)
indices = tf.cast(tf.scan(lambda acc, x: tf.maximum(acc, x), mask_range, initializer=tf.reduce_min(mask_range)), tf.int32)
# cumulative max: [0, 1, 1, 3, 3, 3, 6]
tf.gather(a, indices)
Output:
[[ 1. ,  0. ,  3. ],
 [ 7. ,  5. ,  6. ],
 [ 7. ,  5. ,  6. ],
 [ 0. , 11. ,  1. ],
 [ 0. , 11. ,  1. ],
 [ 0. , 11. ,  1. ],
 [13. , 14. , 16.5]]
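The gather approach can be wrapped into a reusable function. The following is only a sketch under a few assumptions: the name forward_fill_zero_rows is hypothetical, the input is a 2-D tensor, and the zero test uses tf.reduce_any over the entries instead of the row sum, so that a row such as [1., -1., 0.] is not mistaken for a zeroed row. Leading all-zero rows have no earlier row to copy and are left unchanged.

```python
import tensorflow as tf

def forward_fill_zero_rows(a):
    """Replace each all-zero row of a 2-D tensor with the last row that has a non-zero entry."""
    # True for rows that contain at least one non-zero entry.
    nonzero_row = tf.reduce_any(tf.not_equal(a, 0), axis=1)
    row_ids = tf.range(tf.shape(a)[0])
    # Keep a row's own index where it is non-zero, 0 elsewhere.
    candidates = tf.where(nonzero_row, row_ids, tf.zeros_like(row_ids))
    # Running maximum: index of the most recent non-zero row seen so far.
    indices = tf.scan(lambda acc, x: tf.maximum(acc, x), candidates,
                      initializer=tf.constant(0, dtype=tf.int32))
    return tf.gather(a, indices)
```

Calling forward_fill_zero_rows(a) on the example tensor should reproduce the imputed rows shown above. The running maximum still goes through tf.scan, as in the answer, so very long inputs are processed sequentially.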