My TensorFlow version is 2.4.0-rc0.
After debugging, my code seems to have one problem, shown below:
@tf.function
def keypoint_distance(self, kpt):
    if kpt[2] == tf.constant(0.0):
        return tf.ones((self.LABEL_HEIGHT, self.LABEL_WIDTH), dtype=tf.float32)
    else:
        ortho_dist = self.grid - kpt[0:2]
        return tf.linalg.norm(ortho_dist, axis=-1)
Where did I treat a Tensor as a bool?
The full error message I get:
raise errors.OperatorNotAllowedInGraphError(
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
Answer:
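You can't use a tf.Tensor as a Python bool in graph mode: inside a tf.function, the comparison kpt[2] == tf.constant(0.0) is a symbolic tensor whose value isn't known while the graph is being traced. Use tf.cond instead. It takes a scalar boolean tensor as pred, calls either true_fn or false_fn, and returns the result of the branch that was selected.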
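import tensorflow as tf

@tf.function
def keypoint_distance(kpt):
    # tf.cond selects a branch at run time; both branches must be callables.
    return tf.cond(pred=tf.equal(kpt[2], 0.0),
                   true_fn=lambda: tf.ones((3, 4), dtype=tf.float32),
                   false_fn=lambda: tf.linalg.norm(kpt[0:2], axis=-1))

# Passing tensors instead of Python lists avoids a retrace for every new value.
keypoint_distance(tf.constant([1., 2., 3.]))  # pred is False: norm of [1., 2.]
keypoint_distance(tf.constant([1., 2., 0.]))  # pred is True: (3, 4) tensor of ones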
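Note that I removed the objects that weren't defined in your code sample (self.grid, self.LABEL_HEIGHT and self.LABEL_WIDTH), so the (3, 4) shape and the plain norm of kpt[0:2] are only stand-ins.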
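For reference, here is a minimal sketch of the same tf.cond fix applied to the original method, keeping self.grid, self.LABEL_HEIGHT and self.LABEL_WIDTH. The class name KeypointLabeler, the 4 x 6 label size, and the assumption that self.grid holds per-pixel (x, y) coordinates with shape (LABEL_HEIGHT, LABEL_WIDTH, 2) are all hypothetical; they are not taken from the question.

```python
import tensorflow as tf


class KeypointLabeler:
    """Hypothetical container for the attributes used in the question."""

    def __init__(self, label_height=4, label_width=6):
        self.LABEL_HEIGHT = label_height
        self.LABEL_WIDTH = label_width
        # Assumed layout: grid[i, j] == (x=j, y=i), shape (H, W, 2).
        xs, ys = tf.meshgrid(tf.range(label_width, dtype=tf.float32),
                             tf.range(label_height, dtype=tf.float32))
        self.grid = tf.stack([xs, ys], axis=-1)

    @tf.function
    def keypoint_distance(self, kpt):
        # kpt = (x, y, flag); a flag of 0.0 selects the all-ones label.
        return tf.cond(
            pred=tf.equal(kpt[2], 0.0),
            true_fn=lambda: tf.ones((self.LABEL_HEIGHT, self.LABEL_WIDTH),
                                    dtype=tf.float32),
            # (H, W, 2) - (2,) broadcasts; the norm over the last axis
            # gives an (H, W) map of distances to the keypoint.
            false_fn=lambda: tf.linalg.norm(self.grid - kpt[0:2], axis=-1))


labeler = KeypointLabeler()
hidden = labeler.keypoint_distance(tf.constant([1., 2., 0.]))   # (4, 6) ones
visible = labeler.keypoint_distance(tf.constant([1., 2., 1.]))  # (4, 6) distances
```

Because both branches return a (LABEL_HEIGHT, LABEL_WIDTH) tensor here, the output of tf.cond keeps a fully known static shape, which is usually easier to work with downstream than branches of different shapes.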