I am training a VQVAE with this dataset (64x64x3). I have downloaded it locally and loaded it with keras in Jupyter notebook. The problem is that when I ran fit()
to train the model I get this error: ValueError: Layer "vq_vae" expects 1 input(s), but it received 2 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, 64, 64, 3) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int32>]
. I have taken most of the code from here and adapted it myself. But for some reason I can't make it work for other datasets. You can ignore most of the code here and check it in the page, help is much appreciated.
The code I have so far:
img_height = 64
img_width = 64
dataset = tf.keras.utils.image_dataset_from_directory(directory="PATH",
image_size=(64, 64), batch_size=64, shuffle=True)
normalization_layer = tf.keras.layers.Rescaling(1./255)
normalized_ds = dataset.map(lambda x, y: (normalization_layer(x), y))
AUTOTUNE = tf.data.AUTOTUNE
train_ds = normalized_ds.cache().prefetch(buffer_size=AUTOTUNE)
VQVAE code:
class VectorQuantizer(layers.Layer):
def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
super().__init__(**kwargs)
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.beta = (beta)
w_init = tf.random_uniform_initializer()
self.embeddings = tf.Variable(initial_value=w_init(shape=(self.embedding_dim, self.num_embeddings), dtype="float32"),
trainable=True,
name="embeddings_vqvae",)
def call(self, x):
input_shape = tf.shape(x)
flattened = tf.reshape(x, [-1, self.embedding_dim])
# Quantization.
encoding_indices = self.get_code_indices(flattened)
encodings = tf.one_hot(encoding_indices, self.num_embeddings)
quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
quantized = tf.reshape(quantized, input_shape)
commitment_loss = self.beta * tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
self.add_loss(commitment_loss codebook_loss)
# Straight-through estimator.
quantized = x tf.stop_gradient(quantized - x)
return quantized
def get_code_indices(self, flattened_inputs):
# Calculate L2-normalized distance between the inputs and the codes.
similarity = tf.matmul(flattened_inputs, self.embeddings)
distances = (
tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
tf.reduce_sum(self.embeddings ** 2, axis=0)
- 2 * similarity
)
# Derive the indices for minimum distances.
encoding_indices = tf.argmin(distances, axis=1)
return encoding_indices
def get_vqvae(latent_dim=16, num_embeddings=64):
vq_layer = VectorQuantizer(num_embeddings, latent_dim,
name="vector_quantizer")
encoder = get_encoder(latent_dim)
decoder = get_decoder(latent_dim)
inputs = keras.Input(shape=(64, 64, 3))
encoder_outputs = encoder(inputs)
quantized_latents = vq_layer(encoder_outputs)
reconstructions = decoder(quantized_latents)
return keras.Model(inputs, reconstructions, name="vq_vae")
class VQVAETrainer(keras.models.Model):
def __init__(self, train_variance, latent_dim=32, num_embeddings=128, **kwargs):
super(VQVAETrainer, self).__init__(**kwargs)
self.train_variance = train_variance
self.latent_dim = latent_dim
self.num_embeddings = num_embeddings
self.vqvae = get_vqvae(self.latent_dim, self.num_embeddings)
self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
self.reconstruction_loss_tracker = keras.metrics.Mean(
name="reconstruction_loss"
)
self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")
@property
def metrics(self):
return [
self.total_loss_tracker,
self.reconstruction_loss_tracker,
self.vq_loss_tracker,]
def train_step(self, x):
with tf.GradientTape() as tape:
# Outputs from the VQ-VAE.
reconstructions = self.vqvae(x)
# Calculate the losses.
reconstruction_loss = (
tf.reduce_mean((x - reconstructions) ** 2) / self.train_variance
)
total_loss = reconstruction_loss sum(self.vqvae.losses)
# Backpropagation.
grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))
# Loss tracking.
self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.vq_loss_tracker.update_state(sum(self.vqvae.losses))
# Log results.
return {
"loss": self.total_loss_tracker.result(),
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"vqvae_loss": self.vq_loss_tracker.result(),
}
get_vqvae().summary()
Encoder and Decoder (I have made changes here but I dont think this is the problem):
def get_encoder(latent_dim=16):
encoder_inputs = keras.Input(shape=(64, 64, 3))
x = layers.Conv2D(64, 3, padding="same")(encoder_inputs)
x = layers.Dropout(0.25)(x)
x = layers.Activation("relu")(x)
x = layers.UpSampling2D()(x)
x = layers.BatchNormalization(momentum=0.8)(x)
x = layers.Conv2D(64, 3, strides=2, padding="same")(x)
x = layers.Dropout(0.25)(x)
x = layers.Activation("relu")(x)
x = layers.BatchNormalization(momentum=0.8)(x)
x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
x = layers.BatchNormalization(momentum=0.8)(x)
x = layers.Activation("relu")(x)
x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
x = layers.BatchNormalization(momentum=0.8)(x)
x = layers.Activation("relu")(x)
x = layers.Conv2D(256, 3, strides=2, padding="same")(x)
x = layers.Activation("relu")(x)
x = layers.UpSampling2D()(x)
x = layers.BatchNormalization(momentum=0.8)(x)
x = layers.Dense(4*4*128)(x)
encoder_outputs = layers.Conv2D(latent_dim, 1, padding="same")(x)
return keras.Model(encoder_inputs, encoder_outputs, name="encoder")
get_encoder().summary()
def get_decoder(latent_dim=16):
latent_inputs = keras.Input(shape=get_encoder().output.shape[1:])
x = layers.Conv2DTranspose(32, 3, padding="same")(latent_inputs)
x = layers.Dropout(0.25)(x)
x = layers.Activation("relu")(x)
x = layers.BatchNormalization(momentum=0.8)(x)
x = layers.Conv2DTranspose(32, 3, padding="same")(latent_inputs)
x = layers.Dropout(0.25)(x)
x = layers.Activation("relu")(x)
x = layers.BatchNormalization(momentum=0.8)(x)
x = layers.Conv2DTranspose(64, 3, padding="same")(x)
x = layers.Dropout(0.25)(x)
x = layers.Activation("relu")(x)
x = layers.BatchNormalization(momentum=0.8)(x)
x = layers.Conv2DTranspose(128, 3, strides=2, padding="same")(x)
x = layers.Dropout(0.25)(x)
x = layers.Activation("relu")(x)
x = layers.BatchNormalization(momentum=0.8)(x)
x = layers.Conv2DTranspose(256, 3, padding="same")(x)
x = layers.Dropout(0.25)(x)
x = layers.Activation("relu")(x)
x = layers.BatchNormalization(momentum=0.8)(x)
x = layers.Conv2DTranspose(512, 3, strides=2, padding="same")(x)
x = layers.Dropout(0.25)(x)
x = layers.Activation("relu")(x)
x = layers.BatchNormalization(momentum=0.8)(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, padding="same")(x)
return keras.Model(latent_inputs, decoder_outputs, name="decoder")
get_decoder().summary()
I get the error here:
vqvae_trainer = VQVAETrainer(data_variance, latent_dim=16,
num_embeddings=128)
vqvae_trainer.compile(optimizer=tf.keras.optimizers.Adam())
vqvae_trainer.fit(train_ds, epochs=30, batch_size=128)
CodePudding user response:
This kind of model does not work with labels. Try running:
normalized_ds = dataset.map(lambda x, y: normalization_layer(x))
to discard the labels, since you are actually only interested in x
.