Applying tf.math.argmax on list of tensors

Time:08-29

tf.math.argmax returns the index of the maximum value in a tensor.

import tensorflow as tf

a = tf.constant([1,2,3])
print(a)
print(tf.math.argmax(input = a))

output:

tf.Tensor([1 2 3], shape=(3,), dtype=int32)
<tf.Tensor: shape=(), dtype=int64, numpy=2>
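When no axis is given, tf.math.argmax reduces over axis 0. The index dtype defaults to int64, which is why the output above shows dtype=int64, and the output_type parameter can request int32 instead. Here is a minimal sketch on a 2-D tensor. The matrix m is just illustrative data:

```python
import tensorflow as tf

m = tf.constant([[1, 5, 3],
                 [4, 2, 6]])

# Default axis (0): index of the largest value in each column.
cols = tf.math.argmax(m)                                   # [1, 0, 1], int64
# axis=1 (same as -1 here): index of the largest value in each row.
rows = tf.math.argmax(m, axis=1)                           # [1, 2]
# output_type switches the index dtype from the default int64.
rows32 = tf.math.argmax(m, axis=1, output_type=tf.int32)

print(cols.numpy(), rows.numpy(), rows32.dtype)
```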

I want to apply the tf.math.argmax function to a list of tensors. How can I do it?

import tensorflow as tf

input = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(input, num_or_size_splits=2, axis=-1)
print(split_sequence)
tf.math.argmax(input = split_sequence)

output:

[<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>]
tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32)
<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 1, 1])>

It is giving the wrong indices: numpy=array([1, 1, 1])
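Here is a likely explanation. TensorFlow converts the Python list of two equal-shaped (3,) tensors into one tensor of shape (2, 3). Because axis defaults to 0, argmax then compares the two rows column by column. The second row [4, 5, 6] wins every column, so every index is 1. The sketch below reproduces this with an explicit tf.stack:

```python
import tensorflow as tf

inp = tf.constant([1, 2, 3, 4, 5, 6])
split_sequence = tf.split(inp, num_or_size_splits=2, axis=-1)

# The list of two (3,) tensors becomes one (2, 3) tensor: [[1, 2, 3], [4, 5, 6]].
stacked = tf.stack(split_sequence)
print(stacked.shape)  # (2, 3)

# axis defaults to 0, so argmax runs down each column of the stacked tensor.
per_column = tf.math.argmax(stacked)
print(per_column.numpy())  # [1 1 1]
```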

desired output:

numpy=array([[2], [2]])
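You can get that result with a single argmax call. This is one possible approach, not the only fix. Reduce along the last axis (axis=-1) so that each split is handled on its own. Then use tf.expand_dims to add a trailing dimension, which gives shape (2, 1):

```python
import tensorflow as tf

inp = tf.constant([1, 2, 3, 4, 5, 6])
split_sequence = tf.split(inp, num_or_size_splits=2, axis=-1)

# axis=-1: argmax within each split (each row of the packed (2, 3) tensor).
idx = tf.math.argmax(split_sequence, axis=-1)   # shape (2,): [2, 2]
# Add a trailing axis to match the desired [[2], [2]] layout.
idx = tf.expand_dims(idx, axis=-1)              # shape (2, 1)
print(idx.numpy())
```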

Answer:

You can use map to apply a function to each tensor in the list.

(It's better not to use the name of a Python built-in function as a variable, so I changed input to inp.)

import tensorflow as tf

inp = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(inp, num_or_size_splits=2, axis=-1)
print(split_sequence)

result = list(map(lambda x: [tf.math.argmax(x).numpy()] , split_sequence))
print(result)

Or, thanks to @jkr, we can also use a list comprehension. (See also: which one is better, map or a list comprehension?)

>>> [[tf.math.argmax(item).numpy()] for item in split_sequence]
[[2], [2]]
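The comprehension returns a nested Python list. If you want a NumPy array shaped like the desired array([[2], [2]]), you can wrap the comprehension in np.array. This is a sketch that assumes NumPy is importable, which it is wherever TensorFlow is installed:

```python
import numpy as np
import tensorflow as tf

inp = tf.constant([1, 2, 3, 4, 5, 6])
split_sequence = tf.split(inp, num_or_size_splits=2, axis=-1)

# One argmax per split, each wrapped in a one-element list, then packed into an array.
result = np.array([[tf.math.argmax(item).numpy()] for item in split_sequence])
print(result.shape)  # (2, 1)
print(result)
```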

output of the map version:

[
    <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, 
    <tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>
]

[[2], [2]]

Benchmark (on Colab):

import tensorflow as tf
inp = tf.constant([1,2,3,4,5,6]*1_000_000)
split_sequence = tf.split(inp, num_or_size_splits=20, axis=-1)

%timeit tf.math.top_k(split_sequence, k=1).indices
# 13.5 ms ± 394 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


%timeit list(map(lambda x: [tf.math.argmax(x).numpy()] , split_sequence))
# 14 ms ± 2.39 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


%timeit [[tf.math.argmax(item).numpy()] for item in split_sequence]
# 8.77 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
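The benchmark does not include a vectorized tf.math.argmax call with axis=-1. The sketch below could be added to the comparison. It only checks that the approaches agree, and makes no claim about timings. It uses small illustrative data with no repeated maximum inside any chunk, because tie-breaking between top_k and argmax is not guaranteed to match:

```python
import tensorflow as tf

# Illustrative data: no chunk contains a repeated maximum, so ties cannot occur.
inp = tf.constant([3, 9, 1, 7, 2, 8, 6, 4, 5, 0, 11, 10])
split_sequence = tf.split(inp, num_or_size_splits=4, axis=-1)

# Vectorized: one argmax call over the packed (4, 3) tensor, along the last axis.
vectorized = tf.math.argmax(split_sequence, axis=-1)

# The approaches from the benchmark, flattened for comparison.
via_top_k = tf.math.top_k(split_sequence, k=1).indices[:, 0]
per_item = [tf.math.argmax(item).numpy() for item in split_sequence]

print(vectorized.numpy())  # [1 2 0 1]
```

In Colab it could be timed next to the others with %timeit tf.math.argmax(split_sequence, axis=-1).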

Answer:

I would recommend simply using tf.math.top_k in your case:

import tensorflow as tf

input = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(input, num_or_size_splits=2, axis=-1)
x = tf.math.top_k(split_sequence, sorted=False, k=1).indices
print(x)

output:

tf.Tensor(
[[2]
 [2]], shape=(2, 1), dtype=int32)

Afterwards, if you want a NumPy array, just call x.numpy().
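top_k returns int32 indices with a trailing dimension of size k, which gives shape (2, 1) here. If a flat array with one index per split is more convenient, you can apply tf.squeeze before .numpy(). A short sketch:

```python
import tensorflow as tf

inp = tf.constant([1, 2, 3, 4, 5, 6])
split_sequence = tf.split(inp, num_or_size_splits=2, axis=-1)
x = tf.math.top_k(split_sequence, sorted=False, k=1).indices  # shape (2, 1), int32

as_column = x.numpy()                      # NumPy array [[2], [2]]
as_flat = tf.squeeze(x, axis=-1).numpy()   # NumPy array [2, 2]
print(as_column.dtype, as_flat)
```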

Answer:

<tf.Tensor: shape=(), dtype=int64, numpy=2>

You can see numpy=2 in the output, i.e. index 2 of your constant (indices start at 0), which holds the value 3.
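To go from the index back to the value, you can use tf.gather. tf.reduce_max returns the maximum value directly. A minimal sketch using the same constant:

```python
import tensorflow as tf

a = tf.constant([1, 2, 3])
idx = tf.math.argmax(a)           # 2 (zero-based), dtype int64
value = tf.gather(a, idx)         # element at index 2 -> 3
print(int(idx), int(value), int(tf.reduce_max(a)))  # 2 3 3
```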
