Reduce max in tensorflow for grouping rows?

Time:09-28

I have a tensor and want to calculate the max of every column, grouped by certain rows.

Eg: tensor_eg =

[[0.1 0.2 0.4 0.5],
 [0.1 0.8 0.2 0.5],
 [0.1 0.2 0.4 0.5],
 [0.1 0.1 0.6 0.5]]

tf.reduce_max(tensor_eg, axis=0) gives me the max value of each column over all rows. I would like to compute it per group of rows instead, say one max over rows 0 and 1 and another over rows 2 and 3:

Wanted result:

[[0.1 0.8 0.4 0.5],
 [0.1 0.2 0.6 0.5]]

How can I achieve this?

CodePudding user response:

You could slice out each group, reduce it, and join the results with tf.concat:

import tensorflow as tf

x = tf.constant([[0.1, 0.2, 0.4, 0.5],
                 [0.1, 0.8, 0.2, 0.5],
                 [0.1, 0.2, 0.4, 0.5],
                 [0.1, 0.1, 0.6, 0.5]])

tf.concat([tf.reduce_max(x[:2, :], axis=0, keepdims=True),
           tf.reduce_max(x[2:, :], axis=0, keepdims=True)], axis=0)
<tf.Tensor: shape=(2, 4), dtype=float32, numpy=
array([[0.1, 0.8, 0.4, 0.5],
       [0.1, 0.2, 0.6, 0.5]], dtype=float32)>

A more generic approach would be to use tf.math.segment_max:

x = tf.constant([[0.1, 0.2, 0.4, 0.5],
                 [0.1, 0.8, 0.2, 0.5],
                 [0.1, 0.2, 0.4, 0.5],
                 [0.1, 0.1, 0.6, 0.5]])

tf.math.segment_max(x, tf.constant([0, 0, 1, 1]))

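The segment ids tensor needs one entry per row, so its length must equal the size of x's first dimension. The ids must also be sorted in non-decreasing order; for unsorted ids, use tf.math.unsorted_segment_max, which additionally takes a num_segments argument.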
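If the groups are consecutive runs of rows with known sizes, the segment ids can be generated instead of typed by hand. The following is only a minimal sketch: segment_ids_from_sizes and its group_sizes argument are hypothetical names, not part of TensorFlow, and rejecting empty groups is a simplification chosen for this example.

```python
from itertools import chain, repeat


def segment_ids_from_sizes(group_sizes):
    """Return sorted segment ids, one per row, for consecutive row groups.

    group_sizes (hypothetical input) is the number of rows in each group,
    in order: [2, 2] means rows 0-1 form group 0 and rows 2-3 form group 1.
    """
    if any(size <= 0 for size in group_sizes):
        raise ValueError("every group must contain at least one row")
    # Repeat each group index once per row in that group, then flatten.
    return list(chain.from_iterable(
        repeat(group, size) for group, size in enumerate(group_sizes)
    ))


ids = segment_ids_from_sizes([2, 2])
print(ids)  # [0, 0, 1, 1]

# With TensorFlow imported as tf and x defined as above, the ids could be
# passed on like this:
# tf.math.segment_max(x, tf.constant(ids))
```

Because the ids are built group by group, they come out sorted and their length equals the sum of the group sizes, which has to match the number of rows in x.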
