I have this class that inherits from tf.keras.Model:
import tensorflow as tf
from tensorflow.keras.layers import Dense

class Actor(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.linear1 = Dense(128, activation='relu')
        self.linear2 = Dense(256, activation='relu')
        self.linear3 = Dense(3, activation='softmax')

    # model override method
    def call(self, state):
        x = tf.convert_to_tensor(state)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x
It is used like this:
prob = self.actor(np.array([state]))
It works with an input of shape (5,) and returns a (1, 3) tensor, which is what I expect:
state: (5,) data: [0.50267935 0.50267582 0.50267935 0.50268406 0.5026817 ]
prob: (1, 3) data: tf.Tensor([[0.29540768 0.3525798 0.35201252]], shape=(1, 3), dtype=float32)
However, if I pass a higher-dimensional input, it returns a higher-dimensional tensor:
state: (5, 3) data: [[0.50789109 0.49648439 0.49651666]
[0.5078905 0.49648391 0.49648928]
[0.50788815 0.49648356 0.49643452]
[0.50788677 0.4964834 0.49640713]
[0.50788716 0.49648329 0.49635237]]
prob: (1, 5, 3) data: tf.Tensor(
[[[0.34579638 0.342928 0.3112757 ]
[0.34579614 0.34292707 0.31127676]
[0.34579575 0.34292522 0.31127906]
[0.3457955 0.3429243 0.31128016]
[0.34579512 0.34292242 0.3112824 ]]], shape=(1, 5, 3), dtype=float32)
But I still need the output to be (1, 3). I have never used raw Keras models implemented like this. What can I do to fix it?
TensorFlow 2.9.1 with Keras 2.9.0
CodePudding user response:
Looks like you are working on a reinforcement learning problem. Try adding a Flatten layer (or a Reshape layer) at the beginning of your model, so that each state is collapsed into a single feature vector before it reaches the Dense layers:
class Actor(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.linear1 = Dense(128, activation='relu')
        self.linear2 = Dense(256, activation='relu')
        self.linear3 = Dense(3, activation='softmax')

    # model override method
    def call(self, state):
        x = tf.convert_to_tensor(state)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x
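As a quick sanity check, the Flatten version can be called the same way as in the question. This is only a sketch: the random (5, 3) state is a stand-in for real data, and the `actor` instance is created directly rather than stored on an agent as `self.actor`.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten


class Actor(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.linear1 = Dense(128, activation="relu")
        self.linear2 = Dense(256, activation="relu")
        self.linear3 = Dense(3, activation="softmax")

    def call(self, state):
        x = tf.convert_to_tensor(state)
        x = self.flatten(x)  # (batch, 5, 3) -> (batch, 15)
        x = self.linear1(x)
        x = self.linear2(x)
        return self.linear3(x)


actor = Actor()

# Stand-in for a real observation: random values with the (5, 3) shape from the question.
state = np.random.rand(5, 3).astype(np.float32)

# Wrapping the state in a list adds the batch axis: (5, 3) -> (1, 5, 3).
prob = actor(np.array([state]))
print(prob.shape)
```

Keep in mind that the Dense layers create their kernels on the first call. A model first called with (5, 3) states expects 15 features after flattening, so the same instance cannot later be fed (5,) states.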
Also note how the Dense layer handles inputs of rank greater than 2, as described in its documentation:
Note: If the input to the layer has a rank greater than 2, then Dense computes the dot product between the inputs and the kernel along the last axis of the inputs and axis 0 of the kernel (using tf.tensordot). For example, if input has dimensions (batch_size, d0, d1), then we create a kernel with shape (d1, units), and the kernel operates along axis 2 of the input, on every sub-tensor of shape (1, 1, d1) (there are batch_size * d0 such sub-tensors). The output in this case will have shape (batch_size, d0, units).
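The shape rule in that note can be reproduced with plain NumPy. The sketch below is not Keras code: `dense_like` is a hypothetical helper that mimics only the matrix contraction of a Dense layer (no activation), and the kernels are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)


def dense_like(inputs, kernel, bias):
    """Hypothetical stand-in for Dense: contract the last axis of `inputs`
    with axis 0 of `kernel`, then add the bias."""
    return np.tensordot(inputs, kernel, axes=1) + bias


units = 128
batch = rng.random((1, 5, 3))  # (batch_size, d0, d1), as in the question

# Without flattening: the kernel has shape (d1, units) and is applied to every row.
kernel = rng.random((3, units))
per_row = dense_like(batch, kernel, np.zeros(units))
print(per_row.shape)  # (1, 5, 128)

# With flattening: each state becomes one vector of d0 * d1 = 15 features.
flat = batch.reshape(batch.shape[0], -1)
kernel_flat = rng.random((15, units))
flat_out = dense_like(flat, kernel_flat, np.zeros(units))
print(flat_out.shape)  # (1, 128)
```

This is why the original model returns (1, 5, 3): the extra axis of size 5 is carried through all three Dense layers unless it is removed first.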