Efficient pooling operation in Tensorflow : Custom pooling layer

Time:05-07

I wish to create a custom pooling layer that works efficiently on GPUs.

For instance, I have the following input tensor:

in = <tf.Tensor: shape=(4, 5), dtype=float32, numpy=
array([[0., 1., 2., 3., 4.],
       [5., 1., 7., 3., 2.],
       [9., 9., 2., 3., 5.],
       [2., 6., 2., 8., 4.]], dtype=float32)>

I wish to provide a list of column indices over which to perform pooling. For instance, I wish to perform max pooling over the following groups of column indices:

pool_cols =  
[<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)>,
 <tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 3, 4], dtype=int32)>]

The resulting pooled output should look like this:

pooled_out = <tf.Tensor: shape=(4, 2), dtype=float32, numpy=
array([[1., 4.],
       [5., 7.],
       [9., 5.],
       [6., 8.]], dtype=float32)>

What would be the most efficient way to do this?

Answer:

IIUC, you could try something like this using only tf operations, but I'm not sure how efficient that will be on the GPU:

import tensorflow as tf

tensor = tf.constant([[0., 1., 2., 3., 4.],
                      [5., 1., 7., 3., 2.],
                      [9., 9., 2., 3., 5.],
                      [2., 6., 2., 8., 4.]])


pool_cols = [tf.constant([0, 1]), tf.constant([2, 3, 4])]

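def column_max_pooling(tensor, pool_cols):
  results = []
  tensor_shape = tf.shape(tensor)
  for col in pool_cols:
    col_shape = tf.shape(col)
    # Build (row, col) index pairs for every row and every column in this group,
    # ordered column by column: [[0, c0], [1, c0], ..., [0, c1], [1, c1], ...]
    rows = tf.tile(tf.range(tensor_shape[0]), [col_shape[0]])
    cols = tf.repeat(col, [tensor_shape[0]])
    t = tf.gather_nd(tensor, tf.transpose(tf.stack([rows, cols])))
    # Reshape to (group_size, num_rows), transpose to (num_rows, group_size)
    # and take the max over the group's columns -> (num_rows, 1).
    t = tf.reduce_max(tf.transpose(tf.reshape(t, (col_shape[0], tensor_shape[0]))), axis=-1, keepdims=True)
    results.append(t)
  # Stack the per-group results side by side -> (num_rows, num_groups).
  return tf.concat(results, axis=-1)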

print(column_max_pooling(tensor, pool_cols))
tf.Tensor(
[[1. 4.]
 [5. 7.]
 [9. 5.]
 [6. 8.]], shape=(4, 2), dtype=float32)
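A shorter variant of the same loop (a sketch, not benchmarked against the version above) replaces the hand-built gather_nd index pairs with tf.gather(..., axis=1), which selects whole columns directly. The function name column_max_pooling_gather is just an illustrative name:

```python
import tensorflow as tf

tensor = tf.constant([[0., 1., 2., 3., 4.],
                      [5., 1., 7., 3., 2.],
                      [9., 9., 2., 3., 5.],
                      [2., 6., 2., 8., 4.]])
pool_cols = [tf.constant([0, 1]), tf.constant([2, 3, 4])]

def column_max_pooling_gather(tensor, pool_cols):
  # For each group, select its columns along axis 1 -> (num_rows, group_size),
  # then take the max over those columns -> (num_rows, 1).
  pooled = [tf.reduce_max(tf.gather(tensor, cols, axis=1), axis=1, keepdims=True)
            for cols in pool_cols]
  # Concatenate the per-group results -> (num_rows, num_groups).
  return tf.concat(pooled, axis=1)

result = column_max_pooling_gather(tensor, pool_cols)
print(result)
```

It produces the same (4, 2) tensor as above. It still issues one gather and one reduction per group, so the number of ops grows with len(pool_cols).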

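If the concatenation of pool_cols is exactly 0, 1, ..., num_cols - 1 (every column belongs to exactly one group and the groups are listed in column order), you could also try tf.math.unsorted_segment_max, using each column's group index as its segment id: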

import tensorflow as tf

tensor = tf.constant([[0., 1., 2., 3., 4.],
                      [5., 1., 7., 3., 2.],
                      [9., 9., 2., 3., 5.],
                      [2., 6., 2., 8., 4.]])

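pool_cols = [tf.constant([0, 1]), tf.constant([2, 3, 4])]
# Segment id of each column, assigned positionally: [0, 0, 1, 1, 1].
segment_ids = tf.concat([tf.repeat(idx, tf.shape(col)[0]) for idx, col in enumerate(pool_cols)], axis=0)
# Segment ops reduce over axis 0, so transpose columns into rows, pool, and transpose back.
result = tf.transpose(tf.math.unsorted_segment_max(tf.transpose(tensor), segment_ids, num_segments=len(pool_cols)))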
print(result)
tf.Tensor(
[[1. 4.]
 [5. 7.]
 [9. 5.]
 [6. 8.]], shape=(4, 2), dtype=float32)
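Since the question asks for a custom pooling layer, here is a minimal sketch that wraps the segment-max idea in a Keras layer. ColumnMaxPool is a hypothetical name, not a built-in layer, and the sketch assumes pool_cols is given as plain Python lists of ints and that the pooled columns sit on the last axis. Instead of assigning segment ids positionally, it first gathers the listed columns, so the groups do not need to cover every column or appear in ascending order. The index tensors are built once in __init__ rather than on every call:

```python
import tensorflow as tf

class ColumnMaxPool(tf.keras.layers.Layer):
  """Hypothetical layer: max-pools groups of columns on the last axis.

  pool_cols is assumed to be a list of lists of ints, e.g. [[0, 1], [2, 3, 4]].
  """

  def __init__(self, pool_cols, **kwargs):
    super().__init__(**kwargs)
    self.num_groups = len(pool_cols)
    # All pooled column indices in group order, e.g. [0, 1, 2, 3, 4].
    self.flat_cols = tf.constant(
        [int(c) for cols in pool_cols for c in cols], dtype=tf.int32)
    # Group id of each entry in flat_cols, e.g. [0, 0, 1, 1, 1].
    self.segment_ids = tf.constant(
        [g for g, cols in enumerate(pool_cols) for _ in cols], dtype=tf.int32)

  def call(self, inputs):
    # Pick the pooled columns in group order -> (..., num_selected).
    selected = tf.gather(inputs, self.flat_cols, axis=-1)
    # Segment ops reduce over axis 0, so move the column axis to the front,
    # pool -> (num_groups, ...), then reverse the axes back.
    pooled = tf.math.unsorted_segment_max(
        tf.transpose(selected), self.segment_ids, num_segments=self.num_groups)
    return tf.transpose(pooled)

tensor = tf.constant([[0., 1., 2., 3., 4.],
                      [5., 1., 7., 3., 2.],
                      [9., 9., 2., 3., 5.],
                      [2., 6., 2., 8., 4.]])

result = ColumnMaxPool([[0, 1], [2, 3, 4]])(tensor)
print(result)

# Groups in arbitrary order that skip some columns: max(col 4, col 2) and col 0.
reordered = ColumnMaxPool([[4, 2], [0]])(tensor)
print(reordered)
```

I have not measured whether this is faster on a GPU than the per-group loop; the design choice is simply that the layer runs a fixed number of ops (one gather, one segment max, two transposes) no matter how many groups there are.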