I have a tensor and I want to calculate the max of every column, grouped by certain rows.
E.g.: tensor_eg =
[[0.1 0.2 0.4 0.5],
[0.1 0.8 0.2 0.5],
[0.1 0.2 0.4 0.5],
[0.1 0.1 0.6 0.5]]
tf.reduce_max(tensor_eg, axis=0) gives me the max value of each column over all rows. I would like to do it grouped by certain rows, say the max over rows 0 and 1, and the max over rows 2 and 3:
Wanted result:
[[0.1 0.8 0.4 0.5],
[0.1 0.2 0.6 0.5]]
How can I achieve this?
Answer:
Maybe just try slicing and using tf.concat:
import tensorflow as tf
x = tf.constant([[0.1, 0.2, 0.4, 0.5],
                 [0.1, 0.8, 0.2, 0.5],
                 [0.1, 0.2, 0.4, 0.5],
                 [0.1, 0.1, 0.6, 0.5]])
# Max over rows 0-1 and over rows 2-3, each kept as a (1, 4) row, then stacked
tf.concat([tf.reduce_max(x[:2, :], axis=0, keepdims=True),
           tf.reduce_max(x[2:, :], axis=0, keepdims=True)], axis=0)
<tf.Tensor: shape=(2, 4), dtype=float32, numpy=
array([[0.1, 0.8, 0.4, 0.5],
[0.1, 0.2, 0.6, 0.5]], dtype=float32)>
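If every group is a run of the same number of consecutive rows, the per-group slices can be avoided by reshaping instead. The sketch below assumes the row count is divisible by a hypothetical group_size (2 here): it views the 4x4 tensor as 2 groups of 2 rows and takes the max over the within-group axis.

```python
import tensorflow as tf

x = tf.constant([[0.1, 0.2, 0.4, 0.5],
                 [0.1, 0.8, 0.2, 0.5],
                 [0.1, 0.2, 0.4, 0.5],
                 [0.1, 0.1, 0.6, 0.5]])

group_size = 2  # assumption: every group has exactly this many consecutive rows
num_cols = x.shape[1]

# (4, 4) -> (2, 2, 4): axis 0 indexes the group, axis 1 the rows inside it
grouped = tf.reshape(x, [-1, group_size, num_cols])

# Max over the rows of each group -> shape (2, 4)
result = tf.reduce_max(grouped, axis=1)
print(result.numpy())
```

This gives the same (2, 4) result as the concat version with a single reduction, but it only works for equally sized, contiguous groups; groups of different sizes need a segment-based reduction.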
A more generic approach would be to use tf.math.segment_max:
x = tf.constant([[0.1, 0.2, 0.4, 0.5],
                 [0.1, 0.8, 0.2, 0.5],
                 [0.1, 0.2, 0.4, 0.5],
                 [0.1, 0.1, 0.6, 0.5]])
tf.math.segment_max(x, tf.constant([0, 0, 1, 1]))
The segment ids must contain one entry per row, i.e. their length must equal the size of x's first dimension, and they must be sorted in non-decreasing order.
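Rather than typing the ids by hand, they can be derived from a list of group sizes with tf.repeat, and tf.math.unsorted_segment_max covers groups whose rows are not contiguous (its ids need not be sorted, but num_segments must be passed explicitly). This is a sketch; the group_sizes values and the interleaved grouping are made-up examples, not part of the question.

```python
import tensorflow as tf

x = tf.constant([[0.1, 0.2, 0.4, 0.5],
                 [0.1, 0.8, 0.2, 0.5],
                 [0.1, 0.2, 0.4, 0.5],
                 [0.1, 0.1, 0.6, 0.5]])

# Hypothetical group sizes: rows 0-1 form group 0, rows 2-3 form group 1
group_sizes = tf.constant([2, 2])

# Expand sizes into one sorted id per row: [2, 2] -> [0, 0, 1, 1]
segment_ids = tf.repeat(tf.range(tf.size(group_sizes)), group_sizes)

# Contiguous groups: segment_max needs sorted ids, one per row of x
contiguous_max = tf.math.segment_max(x, segment_ids)

# Non-contiguous groups (rows 0 and 2 vs. rows 1 and 3): ids may be unsorted,
# but the number of segments has to be given explicitly
unsorted_ids = tf.constant([0, 1, 0, 1])
interleaved_max = tf.math.unsorted_segment_max(x, unsorted_ids, num_segments=2)

print(contiguous_max.numpy())
print(interleaved_max.numpy())
```

Because the ids come from group_sizes, unequal groups (for example sizes [1, 3]) need no other change to the code.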