My problem is this: during pre-processing, I want to apply a function, randomly selected from a set of functions, to dataset examples using the `tf.data.Dataset` and `tf.function` APIs.
Specifically, my data are 3D volumes and I wish to apply a rotation from a set of 24 predefined rotation functions. I would like to write this code within a `tf.function`, which limits the use of packages like `numpy` and of Python list indexing.
For example, I would like to do something like this:
import tensorflow as tf
# Sketch only: func1 .. func24 stand for the 24 predefined rotations; their bodies are elided.
@tf.function
def func1(tensor):
# Apply some rotation here
...
@tf.function
def func2(tensor):
...
...
@tf.function
def func24(tensor):
...
@tf.function
def apply(tensor):
list_of_funcs = [func1, func2, ..., func24]
# Randomly sample from 0-23
# NOTE(review): for integer dtypes tf.random.uniform excludes maxval, so maxval=23 would only yield 0-22 — confirm against the TF docs.
a = tf.random.uniform([1], minval=0, maxval=23, dtype=tf.int32)
# This is the failing line: `a` is a tf.Tensor, and a Python list cannot be indexed by one (TypeError).
return list_of_funcs[a](tensor)
However, I cannot index `list_of_funcs`, as doing so raises `TypeError: list indices must be integers or slices, not Tensor`. Additionally, I cannot (AFAIK) collect these functions into a `tf.Tensor` and use `tf.gather`.
So my question: how can I reasonably and neatly sample from these functions in a `tf.function`?
CodePudding user response:
Maybe try using `tf.py_function`, which, according to its documentation:
"Wraps a Python function into a TensorFlow op that executes it eagerly."
For example:
import tensorflow as tf
import random
@tf.function
def func1(tensor):
    """Placeholder transform: announce 'func1' and hand the input back unchanged."""
    # NOTE(review): a plain Python print inside a tf.function presumably fires
    # only while the function is being traced — confirm before relying on it.
    print('func1')
    return tensor
@tf.function
def func2(tensor):
    """Placeholder transform: announce 'func2' and hand the input back unchanged."""
    # NOTE(review): a plain Python print inside a tf.function presumably fires
    # only while the function is being traced — confirm before relying on it.
    print('func2')
    return tensor
@tf.function
def func3(tensor):
    """Placeholder transform: announce 'func3' and hand the input back unchanged."""
    # NOTE(review): a plain Python print inside a tf.function presumably fires
    # only while the function is being traced — confirm before relying on it.
    print('func3')
    return tensor
@tf.function
def func4(tensor):
    """Placeholder transform: announce 'func4' and hand the input back unchanged."""
    # NOTE(review): a plain Python print inside a tf.function presumably fires
    # only while the function is being traced — confirm before relying on it.
    print('func4')
    return tensor
@tf.function
def apply(tensor):
    """Apply one of func1..func4, chosen at random, to ``tensor``.

    The choice is made in ordinary Python (``random.choice``) inside
    ``tf.py_function``, which wraps a Python function into a TensorFlow op
    that executes it eagerly. That is what makes Python-level selection
    usable from within a ``tf.function``.

    Args:
        tensor: Input tensor forwarded to the selected function.

    Returns:
        The selected function's output, declared with the input's dtype.
    """
    candidates = (func1, func2, func3, func4)

    def get_random_function_and_apply(t):
        # Runs eagerly, so plain Python randomness and indexing are allowed.
        return random.choice(candidates)(t)

    # Tout follows the input dtype rather than a hard-coded tf.float32, so
    # non-float32 inputs are accepted; assumes every candidate preserves the
    # dtype of its input.
    return tf.py_function(
        func=get_random_function_and_apply, inp=[tensor], Tout=tensor.dtype
    )
# Demo: call apply once on a random-normal 5x5x5 tensor and print the result.
print(apply(tf.random.normal((5, 5, 5))))
'''
func4
tf.Tensor(
[[[ 0.6041213 -2.054427 1.1755397 -0.62914884 -0.00978021]
[ 0.06134182 -1.5529596 -0.3429052 -0.03199977 -1.1796658 ]
[-0.65084136 -1.5009187 -0.43266404 -0.18494445 1.2958355 ]
[-1.6614605 -0.7398612 1.5384725 -0.24926051 -0.5075399 ]
[ 0.7781286 -0.4102168 1.2152135 0.4508075 -1.7295381 ]]
[[-1.0509509 -1.271087 1.9061071 0.61855525 0.58581835]
[ 2.080663 0.43406835 0.32372198 -0.71427256 0.04448809]
[-0.6438594 -1.1245041 -0.4723388 -0.8302859 -2.0056007 ]
[ 1.1778332 0.2977344 0.7516829 1.1387901 -0.71768486]
[-0.44642782 -0.6523012 -0.48157197 -0.8197472 0.3635474 ]]
[[-0.43357274 1.166849 -0.04528571 0.44322303 0.74193203]
[ 1.2332342 0.07857647 1.3399298 0.62153 1.835202 ]
[ 0.48021084 0.36239776 0.16630112 0.59010863 1.8134127 ]
[-1.1444335 1.2445287 -1.2320557 0.08095992 -0.1379302 ]
[-1.101756 -1.8099649 0.18504284 0.15212883 0.33380997]]
[[-0.68228734 -0.82357454 -0.744171 -0.04959428 -1.3200126 ]
[ 0.813062 1.0669035 -0.7924809 -0.0548021 0.8043163 ]
[ 1.6480085 -0.17134379 0.25517386 0.02731211 1.2226027 ]
[-1.9785942 -0.22399756 -0.6814836 1.2065881 -1.7922156 ]
[-0.34833568 -1.0567352 1.5795225 0.14899854 0.5924402 ]]
[[-1.057639 -1.1659449 -0.22045298 0.39324322 -1.3500952 ]
[-0.32044935 0.9534627 0.40809664 -1.0296333 -0.8129102 ]
[-0.13515176 -0.32676768 -0.9333701 0.35130095 -1.5411847 ]
[ 2.090785 0.3497966 0.27694222 0.78199005 -0.08591356]
[ 0.9621986 -2.3930101 -1.1035724 0.27208164 -1.1846163 ]]], shape=(5, 5, 5), dtype=float32)
'''
CodePudding user response:
You can use a bunch of nested `tf.cond` calls. If the condition is met, `tf.cond` calls `true_fn`; otherwise it calls `false_fn`. Since you have more than two functions, you can nest the calls for as many functions as you like. For instance, I'm making functions that multiply the input by either 2, 3, 4 or 5, depending on the value of a random variable.
import tensorflow as tf
# Module-level input value read by the mult_* functions below.
x = 10
@tf.function
def mult_2():
    """Report the selection and return the module-level ``x`` multiplied by 2."""
    message = f'i was 2, returning {x} multiplied by 2'
    tf.print(message)
    product = tf.multiply(x, 2)
    return product
@tf.function
def mult_3():
    """Report the selection and return the module-level ``x`` multiplied by 3."""
    message = f'i was 3, returning {x} multiplied by 3'
    tf.print(message)
    product = tf.multiply(x, 3)
    return product
@tf.function
def mult_4():
    """Report the selection and return the module-level ``x`` multiplied by 4."""
    message = f'i was 4, returning {x} multiplied by 4'
    tf.print(message)
    product = tf.multiply(x, 4)
    return product
@tf.function
def mult_5():
    """Report the selection and return the module-level ``x`` multiplied by 5."""
    message = f'i was 5, returning {x} multiplied by 5'
    tf.print(message)
    product = tf.multiply(x, 5)
    return product
# Draw the multiplier uniformly from {2, 3, 4, 5}. For integer dtypes the
# upper bound of tf.random.uniform is exclusive, so maxval must be 6 for a 5
# to be drawn; a lower bound of 1 would let a drawn 1 fall through to mult_5,
# whose message would then wrongly report "i was 5".
i = tf.random.uniform((), minval=2, maxval=6, dtype=tf.int32)
# tf.cond takes one predicate and two branches, so more than two choices are
# expressed by nesting a further tf.cond inside the false branch.
tf.cond(
    i == 2, mult_2,
    lambda: tf.cond(
        i == 3, mult_3,
        lambda: tf.cond(i == 4, mult_4, mult_5)))
i was 3, returning 10 multiplied by 3
<tf.Tensor: shape=(), dtype=int32, numpy=30>
Note that mult_5
will execute if none of the conditions are met.