How to mask a tensor's time steps prior to event boolean

Time:06-28

I have time-series data shaped [batch_size, horizon, feature]. Events occur every so often, and I mark them with a boolean flag in a separate "meta" tensor: a tensor of the same shape that is all zeros except at the steps where an event occurs, where it is 1.

If an event has occurred within the horizon, I need to prevent my model from looking at the data before that event. So along the second (time) dimension the mask should default to all ones, and the timesteps before a detected event should be zeros.

Only the last event counts: every timestep before it should be 0, even if earlier events occurred.

One-dimensional examples (meta -> mask):

[0, 0, 1, 0] -> [0, 0, 1, 1]
[0, 0, 0, 1] -> [0, 0, 0, 1]
[1, 0, 1, 0] -> [0, 0, 1, 1]
[1, 0, 0, 0] -> [1, 1, 1, 1]
[0, 0, 0, 0] -> [1, 1, 1, 1]
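As a plain-Python reference for the rule above (a sketch; the function name `mask_from_meta` is hypothetical), find the index of the last event in a 1-D meta sequence and zero every step before it. With no events, the "last event" is treated as index 0, so the mask stays all ones, which also covers the `[1, 0, 0, 0]` case:

```python
def mask_from_meta(meta):
    """Return a 0/1 mask that zeroes every step before the last event.

    `meta` is a 1-D sequence of 0/1 event flags (hypothetical helper).
    If it contains no event, the mask is all ones.
    """
    # Index of the last event, or 0 when there is none (keeps every step).
    last_event = max((i for i, flag in enumerate(meta) if flag), default=0)
    return [0 if i < last_event else 1 for i in range(len(meta))]


examples = [
    [0, 0, 1, 0],
    [0, 0, 0, 1],
    [1, 0, 1, 0],
    [1, 0, 0, 0],
    [0, 0, 0, 0],
]
for meta in examples:
    print(meta, "->", mask_from_meta(meta))
```

This reproduces the five examples and is handy for checking a vectorized TensorFlow version against.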

Answer:

Maybe something like this:

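import tensorflow as tf

the_example = tf.constant([[0, 0, 1, 0],
                           [0, 0, 0, 1],
                           [1, 0, 1, 0],
                           [1, 0, 0, 0],
                           [0, 0, 0, 0]])

# Rows with no event keep the default all-ones mask, so set them aside.
the_zero_mask = tf.reduce_all(the_example == 0, axis=-1)
x = tf.boolean_mask(the_example, ~the_zero_mask)
this_shape = tf.shape(x)

# (row, column) scatter indices for every element of the rows that do contain an event.
row_ids = tf.repeat(tf.where(~the_zero_mask)[:, 0], this_shape[-1])
col_ids = tf.cast(tf.tile(tf.range(this_shape[-1]), [this_shape[0]]), dtype=tf.int64)
something_special = tf.stack([row_ids, col_ids], axis=-1)

# Column index of the LAST event in each remaining row.
tell_me_where = tf.where(x == 1)
here = tf.math.unsorted_segment_max(data=tell_me_where[:, 1], segment_ids=tell_me_where[:, 0], num_segments=this_shape[0])

# Runs of ones of length (horizon - last event index), right-aligned by the reverse.
# shape= pads every run to the full horizon; without it, to_tensor() only pads to the
# longest run, and the scatter below fails when no row has its last event at index 0.
ones_runs = tf.ones_like(tf.ragged.range(here, this_shape[-1]))
raggidy_ragged = tf.reverse(ones_runs.to_tensor(shape=this_shape), axis=[-1])

# Write the per-row masks into an all-ones tensor; event-free rows stay all ones.
we_made_it = tf.tensor_scatter_nd_update(tf.ones_like(the_example, dtype=tf.int64), something_special, tf.reshape(raggidy_ragged, [-1]))
print(we_made_it)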
tf.Tensor(
[[0 0 1 1]
 [0 0 0 1]
 [0 0 1 1]
 [1 1 1 1]
 [1 1 1 1]], shape=(5, 4), dtype=int64)
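A shorter alternative, swapped in here and not the answer's own method: a step survives exactly when no event occurs strictly after it, so a reversed, exclusive cumulative sum along the time axis (`tf.math.cumsum` with `exclusive=True, reverse=True`) gives the mask directly, with no ragged tensors, scatters, or special handling for event-free rows. This is a sketch assuming TensorFlow 2.x; the function names are hypothetical, and the 3-D variant for the question's `[batch_size, horizon, feature]` layout assumes a timestep counts as an event if any of its feature flags is nonzero.

```python
import tensorflow as tf


def mask_before_last_event(meta):
    """Mask for a [batch, horizon] 0/1 event tensor (hypothetical helper).

    A step is kept (1) when no event occurs strictly after it, so every step
    before the last event becomes 0 and rows without events stay all ones.
    """
    # events_after[b, t] = sum of meta[b, s] for all s > t.
    events_after = tf.math.cumsum(meta, axis=1, exclusive=True, reverse=True)
    return tf.cast(tf.equal(events_after, 0), meta.dtype)


def mask_before_last_event_3d(meta):
    """Same mask for a [batch, horizon, feature] tensor.

    Assumption: a timestep is an event if any feature flag at that step is nonzero.
    """
    # Collapse the feature axis into a single event flag per timestep.
    step_events = tf.cast(tf.reduce_any(tf.not_equal(meta, 0), axis=-1), meta.dtype)
    step_mask = mask_before_last_event(step_events)
    # Broadcast the per-step mask back across the feature axis.
    return tf.broadcast_to(step_mask[..., tf.newaxis], tf.shape(meta))


the_example = tf.constant([[0, 0, 1, 0],
                           [0, 0, 0, 1],
                           [1, 0, 1, 0],
                           [1, 0, 0, 0],
                           [0, 0, 0, 0]])
print(mask_before_last_event(the_example))
```

On `the_example` this should give the same rows as the output above, as `int32` rather than `int64`. One way to hide the masked steps from the model is to cast the 3-D mask to the data's dtype and multiply it elementwise with the `[batch_size, horizon, feature]` input.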