creating a mask tensor from an index tensor

Time:08-09

The problem: I have an indices tensor of shape [batch_size, seq_len, k] whose elements all lie in the range [0, hidden_dim), with k smaller than hidden_dim. I want to create a mask tensor of shape [batch_size, seq_len, hidden_dim] in which every element indexed by the indices tensor is 1 and all other elements are 0. For example:

indices = [[[0],[1],[2]]] #batch_size=1, seq_len=3, k=1
mask = tf.zeros(shape=(1,3,3)) #batch_size=1, seq_len=3, hidden_dim = 3

How can I get a target mask in which the elements indicated by the indices are 1, i.e.:

target_mask = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]]]

Answer:

This can be accomplished using tf.one_hot, e.g.:

mask = tf.one_hot(indices, depth=hidden_dim, axis=-1)  # [batch, seq_len, k, hidden_dim]

It wasn't clear to me what you want to happen to the k axis. tf.one_hot() keeps that axis as is, i.e. you get a one-hot (delta) vector of length hidden_dim for each [batch-index, seq-index, k-index] tuple.
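If what you want is the single [batch_size, seq_len, hidden_dim] mask from the question, one option (an assumption about how the k entries should combine) is to collapse the k axis with tf.reduce_max, so an element is 1 if any of its k indices selects it. A minimal sketch, assuming TensorFlow 2.x eager execution; indices_to_mask is a hypothetical helper name, not a TensorFlow API:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def indices_to_mask(indices, hidden_dim):
    """Hypothetical helper (not a TensorFlow API).

    Turns an int tensor of shape [batch_size, seq_len, k] with values in
    [0, hidden_dim) into a float32 0/1 mask of shape
    [batch_size, seq_len, hidden_dim].
    """
    # One-hot encode along a new last axis -> [batch_size, seq_len, k, hidden_dim]
    one_hot = tf.one_hot(indices, depth=hidden_dim, axis=-1)
    # Collapse the k axis: an element is 1 if any of the k indices selects it.
    return tf.reduce_max(one_hot, axis=2)


# The example from the question: batch_size=1, seq_len=3, k=1, hidden_dim=3
indices = tf.constant([[[0], [1], [2]]])
print(indices_to_mask(indices, hidden_dim=3).numpy())
# [[[1. 0. 0.]
#   [0. 1. 0.]
#   [0. 0. 1.]]]

# k=2, including a repeated index at the second position
indices_k2 = tf.constant([[[0, 2], [1, 1]]])
print(indices_to_mask(indices_k2, hidden_dim=3).numpy())
# [[[1. 0. 1.]
#   [0. 1. 0.]]]
```

Using reduce_max rather than reduce_sum keeps the mask binary when the same index appears more than once among the k entries of a position; with reduce_sum, the repeated index 1 in the second example would produce a 2.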
