Currently I am writing a function in TensorFlow to calculate states for a board game. I tested the function in eager mode and then decorated it with @tf.function
to speed it up. On my laptop, the function works as expected in both cases, but as soon as I switch to the server (with either GPU or CPU), an error occurs stating that the array dimensions do not match.
Here is the part of the code that results in an error:
```python
import tensorflow as tf

TF_MOVEDIRECTIONS = tf.constant(
    [[1, 0], [1, 1], [0, 1], [-1, 0], [-1, -1], [0, -1]],
    dtype=tf.int64)

@tf.function
def tf_calculations_nonlosing(bstate):
    new_states = tf.TensorArray(dtype=tf.int64,
                                size=0,
                                dynamic_size=True,
                                infer_shape=True,
                                clear_after_read=False)
    ns_idx = tf.constant(0, dtype=tf.int32)
    # bstate is a tf.Tensor with shape = (11, 11)
    empty = tf.where(bstate == 0)
    whites = tf.where(bstate == 1)  # the shape does not fit?
    blacks = tf.where(bstate == 2)
    for move_idx in tf.range(tf.shape(TF_MOVEDIRECTIONS)[0]):
        md = TF_MOVEDIRECTIONS[move_idx]
        new_pos_w1 = tf.expand_dims(whites + md, axis=1)  # error results from this line
        ...
```
The error raised is:
```
ValueError: Dimensions must be equal, but are 0 and 2 for '{{node while/add_1}} = AddV2[T=DT_INT64](while/add_1/Where_1, while/strided_slice_1) with input shapes: [?, 0] and [2]
```
The strange thing is that when I test `whites = tf.where(bstate == 1)`, the resulting shape is always [?, 2]. And, as written above, the code works on my laptop in a Jupyter notebook, so I do not understand what the source of the error is. I compared the TensorFlow versions and both are 2.7.0. Now I have no idea what else might cause this error, and I have not even found a place to start.
Has anybody encountered a similar error? Or does someone have an idea how to fix this behavior?
Answer:
Is your question why the output shape is None?
With a single argument, `tf.where` returns a tensor of shape [number of True elements, rank of the input]. For example:
```python
import tensorflow as tf

# shape (4,): the input has rank 1
mask = [True, False, False, True]
# shape (2, 1): 2 True values, input rank 1
output = tf.where(mask).numpy()
# array([[0], [3]])
```
For details, see the documentation: https://www.tensorflow.org/api_docs/python/tf/where
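For a 2-D board like the one in the question, each row of the result is a (row, column) index pair, so the shape is [n, 2], and adding a move direction of shape [2] broadcasts across all rows. A small eager-mode sketch, assuming a hypothetical 3×3 board in place of the 11×11 one:

```python
import tensorflow as tf

# Hypothetical 3x3 board: 0 = empty, 1 = white, 2 = black.
board = tf.constant([[0, 1, 0],
                     [2, 0, 1],
                     [0, 0, 0]], dtype=tf.int64)

# One (row, col) pair per white stone: shape (2, 2) = (number of True, rank 2).
whites = tf.where(tf.equal(board, 1))
print(whites.numpy())   # [[0 1]
                        #  [1 2]]

# A move direction of shape (2,) broadcasts against every row of `whites`.
md = tf.constant([1, 0], dtype=tf.int64)
shifted = whites + md
print(shifted.numpy())  # [[1 1]
                        #  [2 2]]
```

The broadcast only works because the last dimension of `whites` (the board's rank, 2) matches the length of `md`.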
However, when you wrap a function in `tf.function`, eager execution is replaced by TensorFlow graph execution: https://www.tensorflow.org/api_docs/python/tf/function
So, even though `bstate.shape` is [11, 11], the number of True elements is not known while the graph is being built, and the first dimension of the output shape is None, which denotes a dynamic shape.
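The static/dynamic distinction can be observed directly while a function is traced. The sketch below uses a made-up helper, `where_static_shape`, that traces `tf.where` for a given input signature via `get_concrete_function` (which builds the graph without running it) and records the static shape TensorFlow infers:

```python
import tensorflow as tf

def where_static_shape(spec):
    """Hypothetical helper: trace tf.where(x == 1) for inputs matching `spec`
    and return the static shape TensorFlow infers, without running the graph."""
    shapes = []

    @tf.function
    def traced(x):
        whites = tf.where(tf.equal(x, 1))
        # Plain Python runs only during tracing, so this records the static
        # (graph-time) shape rather than a shape computed from real data.
        shapes.append(whites.shape)
        return whites

    # Builds the graph for this input signature; nothing is executed.
    traced.get_concrete_function(spec)
    return shapes[-1]

board_spec = tf.TensorSpec(shape=[11, 11], dtype=tf.int64)
print(where_static_shape(board_spec).as_list())  # [None, 2]
```

The first dimension stays None because the number of matches is only known at run time; code that needs that count inside the graph can use `tf.shape(whites)[0]` instead of `whites.shape[0]`. The second dimension, by contrast, is fixed by the rank of the input.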
Answer:
I finally found the issue on my own. It had nothing to do with the function itself.
I had simply messed up the input, and that caused this error further down the line. I did not notice it because the other functions called before this one managed to run with the erroneous input, so it took me quite a while to check it.
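The thread does not say what exactly was wrong with the input, but the error message offers a hint: since the second dimension of the `tf.where` result equals the input's rank, a traced shape of [?, 0] suggests that `bstate` had rank 0 (a scalar) rather than shape (11, 11). A sketch under that assumption, with hypothetical function names `shifted_whites` and `checked`: it reproduces the same kind of `ValueError` at trace time, then pins an `input_signature` so that a wrongly shaped board is rejected at the call site instead of deep inside the graph.

```python
import tensorflow as tf

TF_MOVEDIRECTIONS = tf.constant(
    [[1, 0], [1, 1], [0, 1], [-1, 0], [-1, -1], [0, -1]], dtype=tf.int64)

def shifted_whites(bstate):
    # Same core steps as the question: white coordinates plus one move direction.
    whites = tf.where(tf.equal(bstate, 1))  # static shape (None, rank(bstate))
    return whites + TF_MOVEDIRECTIONS[0]    # requires whites.shape[1] == 2

# Assumed faulty input: a scalar "board". Tracing fails because
# shapes (None, 0) and (2,) cannot be broadcast together.
trace_error = None
try:
    tf.function(shifted_whites).get_concrete_function(
        tf.TensorSpec(shape=[], dtype=tf.int64))
except ValueError as err:
    trace_error = err
    print("trace failed:", err)

# Pinning the expected board shape rejects bad inputs before tracing.
checked = tf.function(
    shifted_whites,
    input_signature=[tf.TensorSpec(shape=[11, 11], dtype=tf.int64)])

board = tf.zeros([11, 11], dtype=tf.int64)
print(checked(board).shape.as_list())  # [0, 2]: an empty board has no white stones

try:
    checked(tf.constant(1, dtype=tf.int64))  # wrong shape for the signature
except (TypeError, ValueError) as err:
    # The exception type for a signature mismatch differs between TF versions.
    print("rejected:", type(err).__name__)
```

An alternative that keeps the function flexible is an explicit check such as `tf.debugging.assert_rank(bstate, 2)` at the top of the function body.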