I would like to one-hot encode the index of the maximum value (the argmax) of each input instance inside a Keras model and concatenate it to the inputs. I know how to do this as a preprocessing step, but I would like to execute it in the model.
inputs = [[ 1,2,3,5],
[ 2,4,6,3],
[ 0,7,1,2]]
argmax =[3,2,1]
onehot = [[0,0,0,1],
[0,0,1,0],
[0,1,0,0]]
concated= [[ 1,2,3,5, 0,0,0,1],
[ 2,4,6,3, 0,0,1,0],
[ 0,7,1,2, 0,1,0,0]]
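from tensorflow.keras import backend as K
from tensorflow.keras.layers import Concatenate, Input

inputs = Input(shape=(4,))
argmax = K.argmax(inputs, axis=-1)
onehot = ...  # How to perform this operation with Keras??
concated = Concatenate(axis=-1)([inputs, onehot])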
EDIT: Would it be possible to do something like the following Python (NumPy) code inside the model?
import numpy as np

argmax = [3, 2, 1]
eye = np.eye(4)
eye[argmax]
Out:
array([[0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.]])
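A sketch, assuming eager TensorFlow 2.x: the NumPy idiom eye[argmax] has direct TensorFlow counterparts. tf.gather(tf.eye(4), argmax) picks rows of an identity matrix exactly like the fancy indexing above, and tf.one_hot(argmax, depth=4) builds the same matrix directly from the indices. The variable names here are only for illustration.

```python
import numpy as np
import tensorflow as tf

argmax = tf.constant([3, 2, 1])

# NumPy: index rows of an identity matrix with the argmax values.
np_onehot = np.eye(4)[argmax.numpy()]

# TensorFlow analogue of eye[argmax]: gather rows of tf.eye(4).
gathered = tf.gather(tf.eye(4), argmax)

# tf.one_hot builds the same matrix directly from the indices.
onehot = tf.one_hot(argmax, depth=4)

print(gathered.numpy())
print(onehot.numpy())
```

Both TensorFlow versions return float32 tensors, so they can be concatenated with float inputs without an extra cast.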
Answer:
You could try something like this:
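import tensorflow as tf

# Batch of 2 samples, each with 3 rows of 4 features.
x = tf.constant([[
    [1, 2, 3, 5],
    [2, 4, 6, 3],
    [0, 7, 1, 2]],
    [
    [9, 2, 3, 5],
    [2, 2, 2, 3],
    [0, 1, 5, 2]]
], dtype=tf.float32)

inputs = tf.keras.layers.Input((3, 4))
# Index of the maximum feature in each row: shape (batch, 3).
argmax = tf.argmax(inputs, axis=-1)
# (batch, row) coordinates for every argmax entry: shape (batch, 3, 2).
ij = tf.stack(tf.meshgrid(
    tf.range(tf.shape(inputs)[0], dtype=tf.int64),
    tf.range(tf.shape(inputs)[1], dtype=tf.int64),
    indexing='ij'), axis=-1)
# Full (batch, row, argmax) coordinates: shape (batch, 3, 3).
gather_indices = tf.concat([ij, tf.expand_dims(argmax, axis=-1)], axis=-1)
# Write a 1 at each coordinate into a zero tensor shaped like the inputs
# (built from `inputs`, not `x`, so the model works for any batch size).
onehot = tf.tensor_scatter_nd_update(
    tf.zeros_like(inputs, dtype=tf.float32),
    gather_indices,
    tf.ones_like(argmax, dtype=tf.float32))
outputs = tf.keras.layers.Concatenate()([inputs, onehot])
model = tf.keras.Model(inputs, outputs)
print(model(x))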
tf.Tensor(
[[[1. 2. 3. 5. 0. 0. 0. 1.]
[2. 4. 6. 3. 0. 0. 1. 0.]
[0. 7. 1. 2. 0. 1. 0. 0.]]
[[9. 2. 3. 5. 1. 0. 0. 0.]
[2. 2. 2. 3. 0. 0. 0. 1.]
[0. 1. 5. 2. 0. 0. 1. 0.]]], shape=(2, 3, 8), dtype=float32)
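To see what the index tensor built with tf.meshgrid looks like, here is a small eager sketch on a single hypothetical sample (batch size 1) using the same ops as the model above. Each entry of the result is a (batch, row, argmax) coordinate that tf.tensor_scatter_nd_update then sets to 1.

```python
import tensorflow as tf

# Hypothetical small input: batch of 1, 3 rows, 4 features.
inputs = tf.constant([[[1., 2., 3., 5.],
                       [2., 4., 6., 3.],
                       [0., 7., 1., 2.]]])

argmax = tf.argmax(inputs, axis=-1)  # shape (1, 3), dtype int64

# ij[b, r] == [b, r]: the batch and row coordinate of every argmax entry.
ij = tf.stack(tf.meshgrid(
    tf.range(tf.shape(inputs)[0], dtype=tf.int64),
    tf.range(tf.shape(inputs)[1], dtype=tf.int64),
    indexing='ij'), axis=-1)  # shape (1, 3, 2)

# Append the argmax column index to get full (b, r, k) coordinates.
indices = tf.concat([ij, tf.expand_dims(argmax, axis=-1)], axis=-1)  # shape (1, 3, 3)
print(indices.numpy())
```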
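tf.meshgrid is needed because tf.tensor_scatter_nd_update expects full coordinates: it pairs every batch index and row index with the corresponding argmax column, so gather_indices has shape (batch, 3, 3). A more detailed explanation of what tf.meshgrid is and why it is needed can be found here. You could also consider wrapping these operations in a custom Keras layer.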
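A minimal sketch of such a layer, assuming the TensorFlow backend and a statically known last dimension. The class name ArgmaxOneHotConcat is hypothetical, and this sketch swaps the meshgrid/scatter construction for tf.one_hot, which yields the same one-hot rows for argmax indices. As far as I know, Keras 3 (the tf.keras shipped with TensorFlow 2.16 and later) no longer accepts raw tf functions applied directly to symbolic Input tensors, so wrapping the ops in a layer like this is also the more portable option.

```python
import tensorflow as tf


class ArgmaxOneHotConcat(tf.keras.layers.Layer):
    """Hypothetical layer: appends a one-hot encoding of the argmax
    along the last axis to the input features."""

    def call(self, inputs):
        # Assumes the feature dimension is known when the model is built.
        depth = inputs.shape[-1]
        # Index of the largest feature in each row.
        argmax = tf.argmax(inputs, axis=-1)
        # One-hot vector with a 1 at the argmax position, same dtype as inputs.
        onehot = tf.one_hot(argmax, depth=depth, dtype=inputs.dtype)
        return tf.concat([inputs, onehot], axis=-1)


inputs = tf.keras.layers.Input((3, 4))
outputs = ArgmaxOneHotConcat()(inputs)
model = tf.keras.Model(inputs, outputs)

x = tf.constant([[[1, 2, 3, 5], [2, 4, 6, 3], [0, 7, 1, 2]],
                 [[9, 2, 3, 5], [2, 2, 2, 3], [0, 1, 5, 2]]], dtype=tf.float32)
y = model(x)
print(y.numpy().shape)
```

Keeping the logic in call also means the layer works for any batch size, since nothing depends on a fixed example tensor.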
If you use tf.reduce_max instead of tf.argmax, it becomes much simpler:
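import tensorflow as tf

# Batch of 2 samples, each with 3 rows of 4 features.
x = tf.constant([[
    [1, 2, 3, 5],
    [2, 4, 6, 3],
    [0, 7, 1, 2]],
    [
    [9, 2, 3, 5],
    [2, 2, 2, 3],
    [0, 1, 5, 2]]
], dtype=tf.float32)

inputs = tf.keras.layers.Input((3, 4))
# Row-wise maximum, kept as shape (batch, 3, 1) so it broadcasts against inputs.
row_max = tf.reduce_max(inputs, axis=-1, keepdims=True)
# 1 where the value equals the row maximum, 0 elsewhere.
onehot = tf.where(tf.not_equal(inputs, row_max), tf.zeros_like(inputs), tf.ones_like(inputs))
outputs = tf.keras.layers.Concatenate()([inputs, onehot])
model = tf.keras.Model(inputs, outputs)
print(model(x))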
tf.Tensor(
[[[1. 2. 3. 5. 0. 0. 0. 1.]
[2. 4. 6. 3. 0. 0. 1. 0.]
[0. 7. 1. 2. 0. 1. 0. 0.]]
[[9. 2. 3. 5. 1. 0. 0. 0.]
[2. 2. 2. 3. 0. 0. 0. 1.]
[0. 1. 5. 2. 0. 0. 1. 0.]]], shape=(2, 3, 8), dtype=float32)
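Note, however, that this version is not strictly equivalent: it puts a 1 at every position equal to the row maximum, so a row with a tied maximum gets more than one 1, whereas tf.argmax always selects a single index. That may be one reason you want to use argmax.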
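A quick check on a hypothetical row with a tied maximum shows where the two approaches differ. This is a sketch assuming eager TensorFlow 2.x; note that the TensorFlow docs do not guarantee which of the tied indices tf.argmax returns, only that it returns one.

```python
import tensorflow as tf

# Hypothetical row with a tie for the maximum (two 7s).
x = tf.constant([[1., 7., 7., 2.]])

# reduce_max approach from the answer: 1 wherever the value equals the row max.
max_mask = tf.where(tf.not_equal(x, tf.reduce_max(x, axis=-1, keepdims=True)),
                    tf.zeros_like(x), tf.ones_like(x))

# argmax approach: exactly one 1 per row.
argmax_onehot = tf.one_hot(tf.argmax(x, axis=-1), depth=4)

print(max_mask.numpy())       # two ones in the row
print(argmax_onehot.numpy())  # a single one
```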