I have this code that takes a tensor with a shape of (3, 3) and reshapes it to (9,). After that, it applies the one_hot function, but it throws an error.
This is the code:
import tensorflow as tf
t1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)
t2 = tf.constant([[1], [-1], [1]], dtype=tf.float32)
print(tf.one_hot(tf.reshape(t1, -1), depth=2))
And the error is:
InvalidArgumentError: Value for attr 'TI' of float is not in the list of allowed values: uint8, int32, int64; NodeDef: {{node OneHot}}; Op<name=OneHot; signature=indices:TI, depth:int32, on_value:T, off_value:T -> output:T; attr=axis:int,default=-1; attr=T:type; attr=TI:type,default=DT_INT64,allowed=[DT_UINT8, DT_INT32, DT_INT64]> [Op:OneHot]
I'm working in a Google Colab notebook, so I think the problem might be the TensorFlow version or the data types of the tensor, but any other solutions would be appreciated.
Answer:
You could simply cast your tensor to tf.int32 or another integer type, since tf.one_hot expects integer indices (the error message lists uint8, int32 and int64 as the allowed types):
import tensorflow as tf
t1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)
t2 = tf.constant([[1], [-1], [1]], dtype=tf.float32)  # not used below
# Flatten to shape (9,), cast the float values to int32 indices, then one-hot encode
print(tf.one_hot(tf.cast(tf.reshape(t1, -1), dtype=tf.int32), depth=3))
tf.Tensor(
[[0. 1. 0.]
[1. 0. 0.]
[1. 0. 0.]
[1. 0. 0.]
[0. 1. 0.]
[1. 0. 0.]
[1. 0. 0.]
[1. 0. 0.]
[0. 1. 0.]], shape=(9, 3), dtype=float32)
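Since the flattened values are only 0 and 1, the third column above is always zero. With depth=2, as in the question, the output is: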
tf.Tensor(
[[0. 1.]
[1. 0.]
[1. 0.]
[1. 0.]
[0. 1.]
[1. 0.]
[1. 0.]
[1. 0.]
[0. 1.]], shape=(9, 2), dtype=float32)
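One caveat worth keeping in mind: when tf.cast converts floats to integers it discards the fractional part, so a value such as 0.9999 left over from earlier arithmetic would silently become index 0. A minimal sketch, assuming your float tensor is meant to hold whole-number class indices, rounds before casting. The helper name float_to_one_hot is hypothetical, not a TensorFlow function.

```python
import tensorflow as tf

def float_to_one_hot(x, depth):
    """Hypothetical helper: flatten a float tensor and one-hot encode it.

    Assumes the float values are meant to be whole-number indices.
    tf.cast alone would truncate 0.9999 to 0, so round to the nearest
    integer first.
    """
    flat = tf.reshape(x, [-1])                         # shape (N,)
    indices = tf.cast(tf.round(flat), dtype=tf.int32)  # nearest-integer indices
    return tf.one_hot(indices, depth=depth)            # shape (N, depth), float32

t1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)
print(float_to_one_hot(t1, depth=2))
```

For exact values like the ones in t1 this gives the same result as the plain cast; the rounding only matters when the floats come from a computation that can drift slightly off whole numbers.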