The problem is, I have an indices tensor with shape [batch_size, seq_len, k], and every element in this tensor is in the range [0, hidden_dim). I want to create a mask tensor with shape [batch_size, seq_len, hidden_dim] where every element indexed by the indices tensor is 1 and all other elements are 0. k is smaller than hidden_dim. For example:
indices = [[[0],[1],[2]]] #batch_size=1, seq_len=3, k=1
mask = tf.zeros(shape=(1,3,3)) #batch_size=1, seq_len=3, hidden_dim = 3
How can I get a target mask tensor in which the elements indicated by indices are 1, i.e.:
target_mask = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]]]
Answer:
This can be accomplished using tf.one_hot, e.g.:
mask = tf.one_hot(indices, depth=hidden_dim, axis=-1) # [batch, seq_len, k, hidden_dim]
I wasn't clear on what you'd like to happen to k. tf.one_hot() keeps the k axis as is, i.e. you'll get a delta distribution (a one-hot vector of length hidden_dim) for each [batch-index, seq-index, k-index] tuple.
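If the goal is the [batch_size, seq_len, hidden_dim] mask from the question, one option is to collapse the k axis after the one-hot encoding. The sketch below assumes that is the intent; indices_to_mask is a hypothetical helper name. It uses tf.reduce_max so the mask stays binary (0/1) even when the same index appears more than once along k; tf.reduce_sum would count repeats instead.

```python
import tensorflow as tf

hidden_dim = 3

# Example from the question: batch_size=1, seq_len=3, k=1
indices = tf.constant([[[0], [1], [2]]])

def indices_to_mask(indices, hidden_dim):
    """Hypothetical helper: [batch, seq_len, k] indices -> [batch, seq_len, hidden_dim] mask."""
    # One-hot encode every index: [batch_size, seq_len, k, hidden_dim]
    one_hot = tf.one_hot(indices, depth=hidden_dim, axis=-1)
    # Collapse the k axis; max keeps the mask 0/1 even if an index repeats
    return tf.reduce_max(one_hot, axis=-2)  # [batch_size, seq_len, hidden_dim]

target_mask = indices_to_mask(indices, hidden_dim)
print(target_mask.numpy())

# k=2, with a repeated index at the second sequence position
multi = tf.constant([[[0, 2], [1, 1]]])
print(indices_to_mask(multi, hidden_dim).numpy())
```

The mask comes out as float32, which is tf.one_hot's default; pass dtype=tf.int32 (or another dtype) to tf.one_hot if an integer mask is needed.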