I would like to get the indices of the maximum values along the last axis.
For example, given:
[
[
[0.1 0.3 0.6],
[0.0 0.4 0.1]
],
[
[0.9 0.2 0.6],
[0.8 0.1 0.5]
]
]
I would like to get [[0,0,2], [0,1,1], [1,0,0], [1,1,0]], i.e. the full coordinates of the maximum in every row. What is the easiest way to do that in TensorFlow?
Answer:
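You can take advantage of TensorFlow's broadcasting over the last dimension: compute the maximum of each row with keepdims=True, compare it against the original tensor, and pass the resulting boolean mask to tf.where, which returns the coordinates of every True element.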
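import tensorflow as tf
a = tf.constant([[[0.1, 0.3, 0.6],[0.0, 0.4, 0.1]],[[0.9, 0.2, 0.6],[0.8, 0.1, 0.5]]])
# maximum of each row along the last axis, kept as shape (2, 2, 1) so it broadcasts against a
b = tf.reduce_max(a, -1, keepdims=True)
# coordinates of every element equal to its row maximum
tf.where(a == b)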
Output
<tf.Tensor: shape=(4, 3), dtype=int64, numpy=
array([[0, 0, 2],
[0, 1, 1],
[1, 0, 0],
[1, 1, 0]], dtype=int64)>
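One caveat with this approach: tf.where returns every position that equals the row maximum, so a row containing a tie produces more than one coordinate. A quick check, using a copy of the question's tensor in which the first row has 0.6 twice:

```python
import tensorflow as tf

# Question's tensor, except the first row now has a tie (0.6 appears twice).
a = tf.constant([[[0.1, 0.6, 0.6], [0.0, 0.4, 0.1]],
                 [[0.9, 0.2, 0.6], [0.8, 0.1, 0.5]]])

# Per-row maximum, kept as shape (2, 2, 1) so it broadcasts against a.
row_max = tf.reduce_max(a, axis=-1, keepdims=True)

# One coordinate per True element, in row-major order,
# so the tied row contributes two coordinates.
idx = tf.where(tf.equal(a, row_max))
print(idx.numpy().tolist())
# [[0, 0, 1], [0, 0, 2], [0, 1, 1], [1, 0, 0], [1, 1, 0]]
```

The result has 5 rows instead of 4, which is the case the next part of this answer deals with.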
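If a row contains several maximal values and you only want to keep the index of the first one, you can compute which segment (i.e. which row of the original tensor) each coordinate in the result belongs to, then use tf.math.segment_min to take the smallest index in each segment.
This works because tf.where returns coordinates in row-major order, so the segment ids come out sorted, which segment_min requires.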
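To make the segment-id step concrete, here is the arithmetic for the (2, 2, 3) example in isolation. The coordinates are the ones tf.where produces for the tied tensor with the last column dropped; the exclusive reverse cumulative product of the leading shape gives row-major strides, and the dot product with those strides gives a flat row id.

```python
import tensorflow as tf

# Leading (non-reduced) shape of the (2, 2, 3) example tensor.
leading_shape = tf.constant([2, 2])

# Exclusive reverse cumulative product = row-major strides: [2, 1].
strides = tf.math.cumprod(leading_shape, exclusive=True, reverse=True)

# Leading coordinates of the maxima found by tf.where (tie in the first row).
coords = tf.constant([[0, 0], [0, 0], [0, 1], [1, 0], [1, 1]])

# Flat row id of each coordinate: i * 2 + j * 1.
segment_ids = tf.reduce_sum(coords * strides, axis=1)
print(strides.numpy().tolist(), segment_ids.numpy().tolist())
# [2, 1] [0, 0, 1, 2, 3]
```

The two coordinates from the tied row share segment id 0, so segment_min collapses them into one.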
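a = tf.constant([[[0.1, 0.6, 0.6],[0.0, 0.4, 0.1]],[[0.9, 0.2, 0.6],[0.8, 0.1, 0.5]]])
b = tf.reduce_max(a, -1, keepdims=True)
# all coordinates of maxima, including ties (row-major order)
c = tf.cast(tf.where(a == b), tf.int32)
# row-major strides of the leading axes, e.g. [2, 1] for shape (2, 2)
strides = tf.math.cumprod(tf.shape(a)[:-1], reverse=True, exclusive=True)
# flat row id of each coordinate; ids are already sorted, as segment_min requires
d = tf.reduce_sum(strides * c[:, :-1], axis=1)
# leading coordinates agree within a segment, so the min picks the first maximum
tf.math.segment_min(c, d)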
Output
<tf.Tensor: shape=(4, 3), dtype=int32, numpy=
array([[0, 0, 1],
[0, 1, 1],
[1, 0, 0],
[1, 1, 0]])>
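A different way to keep only the first maximum, swapped in here rather than taken from the answer above, is to mask out every maximum after the first with a running count along the last axis, and then call tf.where once. A minimal sketch; first_max_indices is a hypothetical helper name:

```python
import tensorflow as tf

def first_max_indices(a):
    """Coordinates of the first maximum along the last axis (hypothetical helper)."""
    is_max = tf.equal(a, tf.reduce_max(a, axis=-1, keepdims=True))
    # Running count of maxima seen so far along the last axis.
    running = tf.cumsum(tf.cast(is_max, tf.int32), axis=-1)
    # Keep a position only if it is a maximum and the first one in its row.
    first = tf.logical_and(is_max, tf.equal(running, 1))
    return tf.where(first)

a = tf.constant([[[0.1, 0.6, 0.6], [0.0, 0.4, 0.1]],
                 [[0.9, 0.2, 0.6], [0.8, 0.1, 0.5]]])
result = first_max_indices(a)
print(result.numpy().tolist())
# [[0, 0, 1], [0, 1, 1], [1, 0, 0], [1, 1, 0]]
```

This avoids the segment-id arithmetic and works for any rank, at the cost of one extra cumsum over the tensor.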
Answer:
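import numpy as np
import tensorflow as tf
# argmax gives the index of the maximum along the last axis, but not in the format you want
# (a is the question's tensor, as built in the first snippet of the answer above)
max_index = tf.reshape(tf.math.argmax(a, -1), (-1, 1))
max_index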
<tf.Tensor: shape=(4, 1), dtype=int64, numpy=
array([[2],
[1],
[0],
[0]])>
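# Format output: coordinates of the two leading axes in row-major order, one column per axis
# dtype=tf.int64 matches argmax's output so tf.concat below works on every platform
idx_axis = tf.reshape(tf.constant(np.indices((a.shape[0], a.shape[1])).transpose(1, 2, 0), dtype=tf.int64), (-1, 2))
idx_axis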
<tf.Tensor: shape=(4, 2), dtype=int64, numpy=
array([[0, 0],
[0, 1],
[1, 0],
[1, 1]])>
tf.concat([idx_axis,max_index], axis=1)
<tf.Tensor: shape=(4, 3), dtype=int64, numpy=
array([[0, 0, 2],
[0, 1, 1],
[1, 0, 0],
[1, 1, 0]])>
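The same argmax idea can be written without NumPy and for any number of leading axes. A minimal sketch, assuming TF 2.x eager execution; argmax_coordinates is a hypothetical name. It relies on tf.where over an all-True mask of the leading shape listing every leading coordinate in row-major order, which is the same order the reshaped argmax output uses. Note that, per the tf.math.argmax documentation, which index is returned on ties is not guaranteed.

```python
import tensorflow as tf

def argmax_coordinates(a):
    """Full coordinates of the argmax along the last axis, for rank >= 2 (hypothetical helper)."""
    # One index per position of the leading axes, as an (N, 1) int64 column.
    max_index = tf.reshape(tf.math.argmax(a, axis=-1), (-1, 1))
    # All leading coordinates, row-major, as an (N, rank - 1) int64 tensor.
    leading = tf.where(tf.ones(tf.shape(a)[:-1], dtype=tf.bool))
    return tf.concat([leading, max_index], axis=1)

a = tf.constant([[[0.1, 0.3, 0.6], [0.0, 0.4, 0.1]],
                 [[0.9, 0.2, 0.6], [0.8, 0.1, 0.5]]])
coords = argmax_coordinates(a)
print(coords.numpy().tolist())
# [[0, 0, 2], [0, 1, 1], [1, 0, 0], [1, 1, 0]]
```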