I want to create a custom ReLU function that takes a vector V holding the 2D locations of pixels, i.e. [[1, 3], [1, 1], ...], and applies a ReLU to those pixels across all channels.
The input_tensor is the output of a Conv2D layer, so its shape is (None, 30, 30, 16) (the original image is 32x32x3).
My code (I know it won't actually return an altered input_tensor; it's just to sort out some of the first problems I've encountered):
```python
from tensorflow.keras.layers import Layer

def relu(x):
    if x > 0:
        return x
    else:
        return 0

class Custom_ReLU(Layer):
    def __init__(self):
        super(Custom_ReLU, self).__init__()

    def call(self, input_tensor, V=None):
        for i in range(len(V)):
            relu(input_tensor[ V[i] ])
        return input_tensor
```
When I run this, I get an error saying that I pass a tensor into relu, so my indexing must be wrong. I tried indexing it differently, but got nowhere:
```
ValueError: Exception encountered when calling layer "custom__re_lu" (type Custom_ReLU).

in user code:

    File "/home/adav/prog/tf/main.py", line 63, in call *
        relu(input_tensor[ V[i][0], V[i][1] ])
    File "/home/adav/prog/tf/main.py", line 51, in relu *
        if x > 0:

    ValueError: condition of if statement expected to be `tf.bool` scalar, got Tensor("my_model/custom__re_lu/Greater:0", shape=(30, 16), dtype=bool); to check for None, use `is not None`

Call arguments received:
  • input_tensor=tf.Tensor(shape=(None, 30, 30, 16), dtype=float32)
  • V=array([[ 1, 1],
       [ 1, 4],
       [ 1, 7],
       ...
```
Any help would be appreciated!
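The ValueError comes from the Python `if x > 0:` inside relu. With an input of shape (None, 30, 30, 16), `input_tensor[V[i][0], V[i][1]]` indexes the batch and row axes, so x is a (30, 16) slice rather than a scalar, and a Python `if` cannot branch on a non-scalar boolean tensor. A minimal sketch, assuming TensorFlow 2.x in eager mode and using a random tensor as a stand-in for the Conv2D output, of selecting one pixel across all channels and applying an element-wise ReLU instead:

```python
import tensorflow as tf

def relu(x):
    # Element-wise ReLU: valid for tensors of any shape, unlike a Python `if`
    return tf.maximum(x, 0.0)

# Stand-in for the Conv2D output: (batch, rows, cols, channels)
input_tensor = tf.random.normal((2, 30, 30, 16))

# Two indices select batch 1 and row 3, leaving a (30, 16) slice
wrong = input_tensor[1, 3]
# Pixel (row 1, col 3) across all channels for every sample: shape (2, 16)
right = input_tensor[:, 1, 3, :]

print(wrong.shape)        # (30, 16)
print(relu(right).shape)  # (2, 16)
```

Tensors are immutable, so the rectified values still have to be written back into a new tensor rather than assigned in place.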
Answer:
You could try using tf.gather_nd and tf.tensor_scatter_nd_update:
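```python
import tensorflow as tf

class Custom_ReLU(tf.keras.layers.Layer):
    def __init__(self):
        super(Custom_ReLU, self).__init__()

    def call(self, inputs):
        shape = tf.shape(inputs)
        # Example pixel locations (row, col): every 3rd row/column from 1 to 28
        V = tf.stack([(i, j) for i in tf.range(1, 29, 3) for j in tf.range(1, 29, 3)])
        # Prepend the batch index so each row of `indices` is (batch, row, col)
        indices = tf.concat([tf.expand_dims(tf.repeat(tf.range(0, shape[0]), repeats=tf.shape(V)[0]), axis=-1),
                             tf.tile(V, [shape[0], 1])], axis=-1)
        # Gather the channel vectors at those pixels: shape (batch * len(V), channels)
        y = tf.gather_nd(inputs, indices)
        # ReLU on the gathered values only
        y = tf.where(tf.greater(y, 0.0), y, tf.constant(0.0))
        # Write the rectified values back; all other pixels are left unchanged
        return tf.tensor_scatter_nd_update(inputs, indices, y)

custom_relu = Custom_ReLU()
x = tf.random.normal((2, 30, 30, 16))
print(custom_relu(x))
```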
Try it on a smaller tensor to see which values change.
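The layer above hardcodes V inside call, whereas the question passes V in as an argument. A minimal sketch of that variant, assuming TensorFlow 2.x in eager mode: the class name PixelReLU is hypothetical, V is a list of (row, col) pairs, and tf.nn.relu stands in for the equivalent tf.where expression. It uses the same gather/scatter technique on a 1x2x2x2 tensor, so the effect is easy to inspect:

```python
import tensorflow as tf

class PixelReLU(tf.keras.layers.Layer):
    """Hypothetical variant: ReLU only at the (row, col) pixels in V, across all channels."""

    def call(self, inputs, V):
        V = tf.convert_to_tensor(V, dtype=tf.int32)      # (N, 2) pixel coordinates
        batch = tf.shape(inputs)[0]
        n = tf.shape(V)[0]
        # Batch index for every (sample, pixel) pair: N zeros, then N ones, ...
        b = tf.repeat(tf.range(batch), repeats=n)
        # Pixel coordinates repeated once per sample: shape (batch * N, 2)
        rc = tf.tile(V, [batch, 1])
        indices = tf.concat([b[:, None], rc], axis=-1)    # (batch * N, 3)
        values = tf.gather_nd(inputs, indices)            # (batch * N, channels)
        return tf.tensor_scatter_nd_update(inputs, indices, tf.nn.relu(values))

# One sample, 2x2 pixels, 2 channels
x = tf.constant([[[[-1.0, 2.0], [-3.0, -4.0]],
                  [[5.0, -6.0], [-7.0, 8.0]]]])
V = [[0, 1], [1, 0]]
y = PixelReLU()(x, V=V)
print(y.numpy())
```

Only pixels (0, 1) and (1, 0) are rectified, becoming [0, 0] and [5, 0]; pixels (0, 0) and (1, 1) keep their negative channels.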