Tensorflow Value for attr 'TI' of float is not in the list of allowed values when One Hot

Time:11-13

I have code that takes a tensor of shape (3, 3), reshapes it to (9,), and then applies tf.one_hot to the result, but the one_hot call throws an error.

This is the code:

import tensorflow as tf

t1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)
t2 = tf.constant([[1], [-1], [1]], dtype=tf.float32)

print(tf.one_hot(tf.reshape(t1, -1), depth=2))

And this is the error:

InvalidArgumentError: Value for attr 'TI' of float is not in the list of allowed values: uint8, int32, int64
    ; NodeDef: {{node OneHot}}; Op<name=OneHot; signature=indices:TI, depth:int32, on_value:T, off_value:T -> output:T; attr=axis:int,default=-1; attr=T:type; attr=TI:type,default=DT_INT64,allowed=[DT_UINT8, DT_INT32, DT_INT64]> [Op:OneHot]

I'm working in a Google Colab notebook, so I suspect the problem is either the TensorFlow version or the tensor's data type, but any other solutions would be appreciated.
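Whether the cause is the Colab TensorFlow version or the dtype can be checked directly. The error message itself lists the index types the OneHot op accepts (uint8, int32, int64), so a quick check like the following shows that it is the float32 tensor that gets rejected. This is a minimal sketch assuming TensorFlow 2.x; `has_one_hot_index_dtype` is a hypothetical helper name, not a TensorFlow API.

```python
import tensorflow as tf

# Index dtypes accepted by the OneHot op, as listed in the error message.
ALLOWED_INDEX_DTYPES = (tf.uint8, tf.int32, tf.int64)


def has_one_hot_index_dtype(t):
    """Hypothetical helper: True if `t` can be passed to tf.one_hot as indices."""
    return t.dtype in ALLOWED_INDEX_DTYPES


t1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)
flat = tf.reshape(t1, [-1])

print(flat.dtype.name, has_one_hot_index_dtype(flat))      # float32 False
print(has_one_hot_index_dtype(tf.cast(flat, tf.int32)))    # True
```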

Answer:

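You could simply cast your tensor to tf.int32 (or another integer type) before calling tf.one_hot, since it expects integer indices; as the error message says, the allowed index types are uint8, int32 and int64: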

import tensorflow as tf

t1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)
t2 = tf.constant([[1], [-1], [1]], dtype=tf.float32)

print(tf.one_hot(tf.cast(tf.reshape(t1, -1), dtype=tf.int32), depth=3))

Output:
tf.Tensor(
[[0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]], shape=(9, 3), dtype=float32)
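To see why each row looks the way it does: the flattened tensor is [1, 0, 0, 0, 1, 0, 0, 0, 1], and each index i becomes a row with a 1 in column i. The NumPy function below is a simplified reference sketch of that behaviour for 1-D indices with tf.one_hot's defaults (on_value=1, off_value=0, axis=-1); it is not how TensorFlow implements the op, and `one_hot_reference` is a hypothetical name. Like tf.one_hot, it produces an all-zero row for any index outside [0, depth).

```python
import numpy as np


def one_hot_reference(indices, depth):
    """Simplified NumPy reference for tf.one_hot on 1-D integer indices."""
    indices = np.asarray(indices)
    if indices.ndim != 1:
        raise ValueError("this sketch only handles 1-D indices")
    # Like the OneHot op, reject float indices instead of guessing.
    if not np.issubdtype(indices.dtype, np.integer):
        raise TypeError(f"indices must be integers, got {indices.dtype}")
    out = np.zeros((indices.shape[0], depth), dtype=np.float32)
    # Only indices in [0, depth) get a 1; anything else stays an all-zero row.
    rows = np.nonzero((indices >= 0) & (indices < depth))[0]
    out[rows, indices[rows]] = 1.0
    return out


flat = np.array([1, 0, 0, 0, 1, 0, 0, 0, 1])
print(one_hot_reference(flat, depth=3))
```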

Or, since the flattened tensor only contains 0 and 1, with depth=2:

tf.Tensor(
[[0. 1.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [0. 1.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [0. 1.]], shape=(9, 2), dtype=float32)
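One caveat with the cast: tf.cast from float to int drops the fractional part (1.8 becomes 1), so a float tensor holding non-whole values would be converted to indices without any error. If that is a concern, a slightly more defensive version could check the values first and, optionally, pick the depth from the data (largest index + 1). This is a sketch assuming TensorFlow 2.x eager execution (the Colab default); `one_hot_from_float` is a hypothetical helper, not a TensorFlow API.

```python
import tensorflow as tf


def one_hot_from_float(values, depth=None):
    """Hypothetical helper: one-hot encode a float tensor of whole-number labels.

    Flattens `values`, checks that every entry is a whole number, casts to
    int32 (an index dtype tf.one_hot accepts) and, if `depth` is None, uses
    the largest index + 1 as the depth.
    """
    flat = tf.reshape(values, [-1])
    # tf.cast would silently drop fractional parts (1.8 -> 1), so check first.
    if not bool(tf.reduce_all(tf.equal(flat, tf.round(flat)))):
        raise ValueError("values must be whole numbers to be used as indices")
    indices = tf.cast(flat, tf.int32)
    if depth is None:
        depth = int(tf.reduce_max(indices)) + 1
    return tf.one_hot(indices, depth=depth)


t1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)
print(one_hot_from_float(t1).shape)           # (9, 2)
print(one_hot_from_float(t1, depth=3).shape)  # (9, 3)
```

Negative values such as the -1 entries in t2 pass the whole-number check, but tf.one_hot turns any index outside [0, depth) into an all-zero row, so labels like -1/1 would need to be remapped to non-negative indices before encoding.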