How can I do the same operation in TensorFlow?
import numpy as np
import torch

tensor = np.random.RandomState(42).uniform(size=(2, 4, 2)).astype(np.float32)
tensor = torch.from_numpy(tensor)
index = tensor.max(-1, keepdim=True)[1]                      # argmax along the last axis, shape (2, 4, 1)
output = torch.zeros_like(tensor).scatter_(-1, index, 1.0)   # write 1.0 at each max position
expected output:
tensor([[[0., 1.],
         [1., 0.],
         [1., 0.],
         [0., 1.]],

        [[0., 1.],
         [0., 1.],
         [1., 0.],
         [0., 1.]]])
CodePudding user response:
As always, everything is a bit more complicated with Tensorflow:
import tensorflow as tf
import numpy as np
tensor = np.random.RandomState(42).uniform(size=(2, 4, 2)).astype(np.float32)
tensor = tf.constant(tensor)
# indices of the max along the last axis, shape (2, 4, 1)
_, indices = tf.math.top_k(tensor)
zeros = tf.zeros_like(tensor)
# leading (batch, row) coordinates for every position, shape (2, 4, 2)
ij = tf.stack(tf.meshgrid(
    tf.range(zeros.shape[0], dtype=tf.int32),
    tf.range(zeros.shape[1], dtype=tf.int32),
    indexing='ij'), axis=-1)
# full coordinates (batch, row, argmax column), shape (2, 4, 3)
gathered_indices = tf.concat([ij, indices], axis=-1)
indices_shape = tf.shape(indices)
values = tf.ones((indices_shape[0], indices_shape[1]))
output = tf.tensor_scatter_nd_update(zeros, gathered_indices, values)
print(output)
tf.Tensor(
[[[0. 1.]
  [1. 0.]
  [1. 0.]
  [0. 1.]]

 [[0. 1.]
  [0. 1.]
  [1. 0.]
  [0. 1.]]], shape=(2, 4, 2), dtype=float32)
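The idea is that the meshgrid supplies the leading (batch, row) part of every coordinate while top_k supplies the column, so tensor_scatter_nd_update receives one complete 3-D index per position. A quick sanity check, assuming tensor and output from the snippet above are still in scope:
argmax_cols = tf.argmax(tensor, axis=-1)              # shape (2, 4)
recovered = tf.argmax(output, axis=-1)                # column of the 1.0 in each row
print(bool(tf.reduce_all(argmax_cols == recovered)))  # True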
CodePudding user response:
So, here's a solution I ended up using.
import tensorflow as tf
import numpy as np

tensor = np.random.RandomState(42).uniform(size=(2, 4, 2)).astype(np.float32)
tensor = tf.convert_to_tensor(tensor, dtype=tf.float32)
fill_value = 1.0
indices = tf.argmax(tensor, axis=-1)   # position of the max along the last axis
depth = tensor.shape[-1]               # size of the one-hot axis
output = tf.cast(tf.one_hot(indices, depth, on_value=fill_value), dtype=tf.float32)
print(output)
tf.Tensor(
[[[0. 1.]
  [1. 0.]
  [1. 0.]
  [0. 1.]]

 [[0. 1.]
  [0. 1.]
  [1. 0.]
  [0. 1.]]], shape=(2, 4, 2), dtype=float32)
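If the last dimension is only known at runtime (for example inside a tf.function with a partially unknown input signature), the same idea works with tf.shape instead of the static .shape attribute. A minimal sketch of that variant; max_one_hot is just an illustrative name, not part of any API:
import tensorflow as tf

def max_one_hot(x, fill_value=1.0):
    # one-hot mask of the maximum along the last axis; tf.shape keeps it
    # working when the last dimension is only known at runtime
    depth = tf.shape(x)[-1]
    return tf.one_hot(tf.argmax(x, axis=-1), depth,
                      on_value=fill_value, off_value=0.0, dtype=x.dtype)

# usage: max_one_hot(tensor) reproduces the output above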