How to randomly select from set of functions in TensorFlow using tf.function


My problem is this: during pre-processing, I want to apply a function randomly selected from a set of functions to dataset examples, using the tf.data.Dataset and tf.function APIs.

Specifically, my data are 3D volumes, and I wish to apply one rotation from a set of 24 predefined rotation functions. I would like to write this code inside a tf.function, which limits the use of packages like NumPy and of Python list indexing.

For example, I would like to do something like this:

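import tensorflow as tf

def func1(tensor):
    # Apply some rotation here (placeholder body)
    return tensor

def func2(tensor):
    # Another rotation (placeholder body)
    return tensor

# ... func3 through func23 ...

def func24(tensor):
    # Last rotation (placeholder body)
    return tensor

def apply(tensor):
    list_of_funcs = [func1, func2, ..., func24]

    # Randomly sample an index in [0, 24); maxval is exclusive
    a = tf.random.uniform([], minval=0, maxval=24, dtype=tf.int32)
    # Fails: a Python list cannot be indexed with a Tensor
    return list_of_funcs[a](tensor)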

However, I cannot index list_of_funcs with a tensor: it raises TypeError: list indices must be integers or slices, not Tensor. Additionally, as far as I know, I cannot collect these functions into a tf.Tensor and use tf.gather.

So my question: how can I reasonably and neatly sample from these functions in a tf.function?

Answer:

Maybe try using tf.py_function, which:

Wraps a python function into a TensorFlow op that executes it eagerly.

For example:

import tensorflow as tf
import random

def func1(tensor):
    return tensor

def func2(tensor):
    return tensor

def func3(tensor):
    return tensor

def func4(tensor):
    return tensor

def apply(tensor):
    dispatcher = {
        'func1': func1,
        'func2': func2,
        'func3': func3,
        'func4': func4
    }
    keys = list(dispatcher)

    def get_random_function_and_apply(t):
        # Runs eagerly inside tf.py_function, so Python's random module works here
        return dispatcher[random.choice(keys)](t)

    y = tf.py_function(func=get_random_function_and_apply, inp=[tensor], Tout=tf.float32)
    return y

print(apply(tf.random.normal((5, 5, 5))))

tf.Tensor(
[[[ 0.6041213  -2.054427    1.1755397  -0.62914884 -0.00978021]
  [ 0.06134182 -1.5529596  -0.3429052  -0.03199977 -1.1796658 ]
  [-0.65084136 -1.5009187  -0.43266404 -0.18494445  1.2958355 ]
  [-1.6614605  -0.7398612   1.5384725  -0.24926051 -0.5075399 ]
  [ 0.7781286  -0.4102168   1.2152135   0.4508075  -1.7295381 ]]

 [[-1.0509509  -1.271087    1.9061071   0.61855525  0.58581835]
  [ 2.080663    0.43406835  0.32372198 -0.71427256  0.04448809]
  [-0.6438594  -1.1245041  -0.4723388  -0.8302859  -2.0056007 ]
  [ 1.1778332   0.2977344   0.7516829   1.1387901  -0.71768486]
  [-0.44642782 -0.6523012  -0.48157197 -0.8197472   0.3635474 ]]

 [[-0.43357274  1.166849   -0.04528571  0.44322303  0.74193203]
  [ 1.2332342   0.07857647  1.3399298   0.62153     1.835202  ]
  [ 0.48021084  0.36239776  0.16630112  0.59010863  1.8134127 ]
  [-1.1444335   1.2445287  -1.2320557   0.08095992 -0.1379302 ]
  [-1.101756   -1.8099649   0.18504284  0.15212883  0.33380997]]

 [[-0.68228734 -0.82357454 -0.744171   -0.04959428 -1.3200126 ]
  [ 0.813062    1.0669035  -0.7924809  -0.0548021   0.8043163 ]
  [ 1.6480085  -0.17134379  0.25517386  0.02731211  1.2226027 ]
  [-1.9785942  -0.22399756 -0.6814836   1.2065881  -1.7922156 ]
  [-0.34833568 -1.0567352   1.5795225   0.14899854  0.5924402 ]]

 [[-1.057639   -1.1659449  -0.22045298  0.39324322 -1.3500952 ]
  [-0.32044935  0.9534627   0.40809664 -1.0296333  -0.8129102 ]
  [-0.13515176 -0.32676768 -0.9333701   0.35130095 -1.5411847 ]
  [ 2.090785    0.3497966   0.27694222  0.78199005 -0.08591356]
  [ 0.9621986  -2.3930101  -1.1035724   0.27208164 -1.1846163 ]]], shape=(5, 5, 5), dtype=float32)
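Since the question applies the random function inside tf.data.Dataset.map, here is a minimal sketch of the same idea used there, assuming rank-3 float32 volumes. func1 and func2 are hypothetical placeholders (axis flips, not real rotations), and augment and FUNCS are illustrative names. The output of tf.py_function has no static shape, so the sketch restores it with set_shape.

```python
import random

import tensorflow as tf


# Hypothetical placeholders for the real rotation functions; each one just
# flips the volume along a different axis so the effect is easy to check.
def func1(t):
    return tf.reverse(t, axis=[0])


def func2(t):
    return tf.reverse(t, axis=[1])


FUNCS = [func1, func2]


def _apply_random(t):
    # Executes eagerly inside tf.py_function, so Python's random module is fine.
    return random.choice(FUNCS)(t)


def augment(volume):
    out = tf.py_function(func=_apply_random, inp=[volume], Tout=tf.float32)
    # py_function loses the static shape; restore it for downstream ops.
    out.set_shape(volume.shape)
    return out


volumes = tf.random.normal((8, 5, 5, 5))
ds = tf.data.Dataset.from_tensor_slices(volumes).map(augment)

for v in ds.take(2):
    print(v.shape)  # (5, 5, 5)
```

Keep in mind that tf.py_function acquires the Python GIL, which limits parallelism inside map, and its Python body is not serialized with the graph, so it will not survive export to a SavedModel. A graph-only approach such as tf.cond avoids both issues.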


Answer:

You can use nested tf.cond calls. tf.cond evaluates its predicate and calls true_fn if it is true, or false_fn otherwise. Since you have more than two functions, you can nest further tf.cond calls in the false branch, as deep as you need. For instance, I'm making functions that multiply x by either 2, 3, 4, or 5, depending on the value of a random integer.

import tensorflow as tf

x = 10

def mult_2():
    tf.print(f'i was 2, returning {x} multiplied by 2')
    return tf.multiply(x, 2)

def mult_3():
    tf.print(f'i was 3, returning {x} multiplied by 3')
    return tf.multiply(x, 3)

def mult_4():
    tf.print(f'i was 4, returning {x} multiplied by 4')
    return tf.multiply(x, 4)

def mult_5():
    tf.print(f'i was 5, returning {x} multiplied by 5')
    return tf.multiply(x, 5)

# i is drawn from {2, 3, 4, 5}; maxval is exclusive
i = tf.random.uniform((), minval=2, maxval=6, dtype=tf.int32)

tf.cond(i == 2, mult_2,
        lambda: tf.cond(i == 3, mult_3,
                        lambda: tf.cond(i == 4, mult_4, mult_5)))
i was 3, returning 10 multiplied by 3
<tf.Tensor: shape=(), dtype=int32, numpy=30>

Note that mult_5 will execute if none of the conditions are met.
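As a flatter alternative to nesting tf.cond by hand, tf.switch_case takes an integer index tensor and a list of zero-argument callables, and runs only the selected branch. Below is a minimal sketch for the 24-rotation case. It assumes each example is a cubic rank-3 volume (D == H == W, no channel axis), so every rotation keeps the shape. The helpers rot90, _spin, ROTATIONS, and random_rotation are hypothetical names; rot90 mirrors numpy.rot90's reverse/transpose logic using TensorFlow ops.

```python
import tensorflow as tf


def rot90(x, k, axes):
    """Rotate rank-3 tensor x by k * 90 degrees in the plane `axes`.

    Mirrors numpy.rot90 using only tf.reverse and tf.transpose.
    """
    a, b = axes
    k %= 4
    perm = [0, 1, 2]
    perm[a], perm[b] = perm[b], perm[a]  # swap the two plane axes
    if k == 0:
        return x
    if k == 1:
        return tf.transpose(tf.reverse(x, axis=[b]), perm)
    if k == 2:
        return tf.reverse(x, axis=[a, b])
    return tf.reverse(tf.transpose(x, perm), axis=[b])


def _spin(pre, axis):
    """Return 4 functions: optional pre-rotation, then a spin about `axis`."""
    plane = tuple(ax for ax in range(3) if ax != axis)

    def make(k):
        def fn(x):
            if pre is not None:
                x = rot90(x, pre[0], pre[1])
            return rot90(x, k, plane)
        return fn

    return [make(k) for k in range(4)]


# Six choices for where the original axis-0 direction ends up,
# times four spins about that direction = 24 rotations.
ROTATIONS = (
    _spin(None, 0)            # stays along +axis 0
    + _spin((2, (0, 2)), 0)   # 180 deg about axis 1: now along -axis 0
    + _spin((1, (0, 2)), 2)   # 90 deg about axis 1: now along axis 2
    + _spin((3, (0, 2)), 2)   # 270 deg about axis 1: opposite side of axis 2
    + _spin((1, (0, 1)), 1)   # 90 deg about axis 2: now along axis 1
    + _spin((3, (0, 1)), 1)   # 270 deg about axis 2: opposite side of axis 1
)


@tf.function
def random_rotation(volume):
    # Sample a branch index uniformly from [0, 24); maxval is exclusive.
    index = tf.random.uniform([], minval=0, maxval=len(ROTATIONS), dtype=tf.int32)

    def branch(fn):
        # switch_case expects zero-argument callables, so bind the volume here.
        return lambda: fn(volume)

    # All branches are traced, but only the one selected by `index` runs.
    return tf.switch_case(index, [branch(fn) for fn in ROTATIONS])


volumes = tf.random.normal((8, 16, 16, 16))
ds = tf.data.Dataset.from_tensor_slices(volumes).map(random_rotation)
```

Compared with nested tf.cond, the list of branches is built in a loop, so adding or removing functions does not change the nesting depth of the code.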
