I would like to get the value with the maximum absolute value in a tensor, with respect to an axis. Note that I don't want the maximum absolute value, I want the value that has the maximum absolute value (so I need to keep the sign).
Ideally, I would like something similar to `reduce_max` or `reduce_min`:
# Desired behaviour, written against a hypothetical op: the author found no
# `reduce_maxamplitude` in TensorFlow, so the three calls below are a
# specification of the expected results, not runnable code.
tensor = tf.constant(
[
[[ 1, 5, -3],
[ 2, -3, 1],
[ 3, -6, 2]],
[[-2, 3, -5],
[-1, 4, 2],
[ 4, -1, 0]]
]
)
# tensor.shape = (2, 3, 3)
# Each result keeps, per position, the signed element of largest |value|.
tensor.reduce_maxamplitude(tensor, axis=0)
# Tensor(
# [[-2, 5, -5],
# [ 2, 4, 2],
# [ 4, -6, 2]]
# )
# shape: (3, 3)
tensor.reduce_maxamplitude(tensor, axis=1)
# Tensor(
# [[3, -6, -3],
# [4, 4, -5]]
# )
# shape: (2, 3)
tensor.reduce_maxamplitude(tensor, axis=2)
# Tensor(
# [[5, -3, -6],
# [-5, 4, 4]]
# )
# shape: (2, 3)
but I did not find anything useful in the TensorFlow documentation.
With a flat tensor, I know that I could use `tf.foldl` or `tf.foldr`:
# Flatten to 1-D, then fold from the right, replacing the running winner
# whenever an element has a strictly larger magnitude.
flat = tf.reshape(tensor, [-1])
tf.foldr(lambda acc, elem: elem if tf.abs(elem) > tf.abs(acc) else acc, flat)
# -6
However, I don't know how to handle an axis parameter in the case of multidimensional tensors.
Answer:
It really depends on how many dimensions your tensor has, but for a 2D tensor you could just do:
import tensorflow as tf
tensor = tf.constant([
    [1, 5, -3],
    [2, -3, 1],
    [3, -6, 2],
])
# Per row, the column holding the largest-magnitude entry; gather it with its sign.
col_of_peak = tf.argmax(tf.abs(tensor), axis=1)
tf.gather(tensor, col_of_peak, axis=1, batch_dims=1)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 5, -3, -6], dtype=int32)>
3D example:
tensor = tf.constant([
    [[1, 5, -3],
     [2, -3, 1],
     [3, -6, 2]],
    [[-2, 3, -5],
     [-1, 4, 2],
     [4, -1, 0]],
])
# axis = 0: the result keeps axes 1 and 2, so build their coordinate grids
# and let argmax supply the coordinate along the reduced axis.
rows, cols = tf.meshgrid(
    tf.range(tensor.shape[1], dtype=tf.int64),
    tf.range(tensor.shape[2], dtype=tf.int64),
    indexing='ij',
)
argmax = tf.argmax(tf.abs(tensor), axis=0)
tf.gather_nd(tensor, tf.stack([argmax, rows, cols], axis=-1))
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[-2, 5, -5],
[ 2, 4, 2],
[ 4, -6, 2]], dtype=int32)>
# axis = 1: coordinate grids for the kept axes 0 and 2; argmax fills axis 1.
batch_idx, col_idx = tf.meshgrid(
    *[tf.range(tensor.shape[d], dtype=tf.int64) for d in (0, 2)],
    indexing='ij',
)
argmax = tf.argmax(tf.abs(tensor), axis=1)
tf.gather_nd(tensor, tf.stack([batch_idx, argmax, col_idx], axis=-1))
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[ 3, -6, -3],
[ 4, 4, -5]], dtype=int32)>
# axis = 2
# argmax must be recomputed along axis 2. The axis=1 result from the
# previous snippet has the same (2, 3) shape, so reusing it raises no error
# but gathers the wrong elements ([[-3, 1, 3], [-5, 4, 4]]).
argmax = tf.argmax(tf.abs(tensor), axis=2)
i, j = tf.meshgrid(
    tf.range(tensor.shape[0], dtype=tf.int64),
    tf.range(tensor.shape[1], dtype=tf.int64),
    indexing='ij')
tf.gather_nd(tensor, tf.stack([i, j, argmax], axis=-1))
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[ 5, -3, -6],
[-5, 4, 4]], dtype=int32)>
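For a 4D tensor, just extend the meshgrid (and stack `argmax` into the position of the reduced axis, e.g. `[i, j, k, argmax]` for the last axis):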
# axis=-1 on a 4D tensor: one coordinate grid per kept axis (0, 1 and 2).
i, j, k = tf.meshgrid(
    *[tf.range(tensor.shape[d], dtype=tf.int64) for d in range(3)],
    indexing='ij',
)
A quick function bundling everything (by @leleogere):
def reduce_maxamplitude(tensor, axis):
    """Reduce ``tensor`` along ``axis``, keeping the signed value of largest magnitude.

    Like ``tf.reduce_max``/``tf.reduce_min``, but for each output position it
    returns the element whose absolute value is largest, with its original
    sign. Ties are resolved by ``tf.argmax``, which does not guarantee which
    of the tied elements is chosen.

    Args:
        tensor: Input tensor. Its rank must be known statically; individual
            dimension sizes may be dynamic (e.g. an unknown batch size).
        axis: Python int, the axis to reduce. Negative values count from the
            last axis, as in ``tf.reduce_max``.

    Returns:
        A tensor with the dtype of ``tensor`` and ``axis`` removed from its shape.

    Raises:
        ValueError: If the rank of ``tensor`` is unknown or ``axis`` is out of range.
    """
    rank = tensor.shape.rank
    if rank is None:
        raise ValueError("reduce_maxamplitude requires a tensor of known rank")
    if not (-rank <= axis < rank):
        raise ValueError(f"axis {axis} is out of range for a tensor of rank {rank}")
    # Normalise negative axes: both the `d != axis` filter and the
    # `mesh[:axis]` split below assume 0 <= axis < rank.
    axis %= rank

    # Position of the largest-magnitude element along `axis` (int64, axis dropped).
    argmax = tf.argmax(tf.abs(tensor), axis=axis)

    # Dynamic shape so dimensions unknown at trace time still work; int64 so
    # the grids match argmax's dtype when stacked.
    shape = tf.shape(tensor, out_type=tf.int64)
    mesh = tf.meshgrid(
        *[tf.range(shape[d], dtype=tf.int64) for d in range(rank) if d != axis],
        indexing="ij",
    )
    # Full coordinates: grids for every kept axis, argmax slotted in at `axis`.
    indices = tf.stack([*mesh[:axis], argmax, *mesh[axis:]], axis=-1)
    return tf.gather_nd(tensor, indices)