Home > Back-end >  ValueError: Layer "vq_vae" expects 1 input(s), but it received 2 input tensors on a VQVAE
ValueError: Layer "vq_vae" expects 1 input(s), but it received 2 input tensors on a VQVAE

Time:03-22

I am training a VQVAE with this dataset (64x64x3). I have downloaded it locally and loaded it with keras in Jupyter notebook. The problem is that when I ran fit() to train the model I get this error: ValueError: Layer "vq_vae" expects 1 input(s), but it received 2 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, 64, 64, 3) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int32>] . I have taken most of the code from here and adapted it myself. But for some reason I can't make it work for other datasets. You can ignore most of the code here and check it in the page, help is much appreciated.

The code I have so far:

img_height = 64
img_width = 64
dataset = tf.keras.utils.image_dataset_from_directory(directory="PATH", 
image_size=(64, 64), batch_size=64, shuffle=True)
normalization_layer = tf.keras.layers.Rescaling(1./255)
normalized_ds = dataset.map(lambda x, y: (normalization_layer(x), y))

AUTOTUNE = tf.data.AUTOTUNE

train_ds = normalized_ds.cache().prefetch(buffer_size=AUTOTUNE)

VQVAE code:

class VectorQuantizer(layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = (beta)

    
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(initial_value=w_init(shape=(self.embedding_dim, self.num_embeddings), dtype="float32"),
        trainable=True,
        name="embeddings_vqvae",)

    def call(self, x):
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

    
        commitment_loss = self.beta * tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss   codebook_loss)

        # Straight-through estimator.
        quantized = x   tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        # Calculate L2-normalized distance between the inputs and the codes.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
        tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
          tf.reduce_sum(self.embeddings ** 2, axis=0)
        - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices

    def get_vqvae(latent_dim=16, num_embeddings=64):
       vq_layer = VectorQuantizer(num_embeddings, latent_dim, 
       name="vector_quantizer")
       encoder = get_encoder(latent_dim)
       decoder = get_decoder(latent_dim)
       inputs = keras.Input(shape=(64, 64, 3))
       encoder_outputs = encoder(inputs)
       quantized_latents = vq_layer(encoder_outputs)
       reconstructions = decoder(quantized_latents)
       return keras.Model(inputs, reconstructions, name="vq_vae")

class VQVAETrainer(keras.models.Model):
    def __init__(self, train_variance, latent_dim=32, num_embeddings=128, **kwargs):
        super(VQVAETrainer, self).__init__(**kwargs)
        self.train_variance = train_variance
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings

        self.vqvae = get_vqvae(self.latent_dim, self.num_embeddings)

        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
        name="reconstruction_loss"
        )
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")

        @property
        def metrics(self):
            return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.vq_loss_tracker,]

        def train_step(self, x):
           with tf.GradientTape() as tape:
           # Outputs from the VQ-VAE.
           reconstructions = self.vqvae(x)

           # Calculate the losses.
            reconstruction_loss = (
            tf.reduce_mean((x - reconstructions) ** 2) / self.train_variance
        )
        total_loss = reconstruction_loss   sum(self.vqvae.losses)

           # Backpropagation.
           grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
           self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))

         # Loss tracking.
         self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.vq_loss_tracker.update_state(sum(self.vqvae.losses))

       # Log results.
        return {
         "loss": self.total_loss_tracker.result(),
          "reconstruction_loss": self.reconstruction_loss_tracker.result(),
          "vqvae_loss": self.vq_loss_tracker.result(),
    }

     get_vqvae().summary()

Encoder and Decoder (I have made changes here but I dont think this is the problem):

def get_encoder(latent_dim=16):
    encoder_inputs = keras.Input(shape=(64, 64, 3))
    x = layers.Conv2D(64, 3, padding="same")(encoder_inputs)
    x = layers.Dropout(0.25)(x)
    x = layers.Activation("relu")(x)
    x = layers.UpSampling2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Conv2D(64, 3, strides=2, padding="same")(x)
    x = layers.Dropout(0.25)(x)
    x = layers.Activation("relu")(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Activation("relu")(x)
    x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Activation("relu")(x)
    x = layers.Conv2D(256, 3, strides=2, padding="same")(x)
    x = layers.Activation("relu")(x)
    x = layers.UpSampling2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Dense(4*4*128)(x)
    encoder_outputs = layers.Conv2D(latent_dim, 1, padding="same")(x)
    return keras.Model(encoder_inputs, encoder_outputs, name="encoder")

get_encoder().summary()


def get_decoder(latent_dim=16):
    latent_inputs = keras.Input(shape=get_encoder().output.shape[1:])
    x = layers.Conv2DTranspose(32, 3, padding="same")(latent_inputs)
    x = layers.Dropout(0.25)(x)
    x = layers.Activation("relu")(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Conv2DTranspose(32, 3, padding="same")(latent_inputs)
    x = layers.Dropout(0.25)(x)
    x = layers.Activation("relu")(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Conv2DTranspose(64, 3, padding="same")(x)
    x = layers.Dropout(0.25)(x)
    x = layers.Activation("relu")(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Conv2DTranspose(128, 3, strides=2, padding="same")(x)
    x = layers.Dropout(0.25)(x)
    x = layers.Activation("relu")(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Conv2DTranspose(256, 3, padding="same")(x)
    x = layers.Dropout(0.25)(x)
    x = layers.Activation("relu")(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Conv2DTranspose(512, 3, strides=2, padding="same")(x)
    x = layers.Dropout(0.25)(x)
    x = layers.Activation("relu")(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    decoder_outputs = layers.Conv2DTranspose(1, 3, padding="same")(x)
    return keras.Model(latent_inputs, decoder_outputs, name="decoder")

get_decoder().summary()

I get the error here:

vqvae_trainer = VQVAETrainer(data_variance, latent_dim=16, 
num_embeddings=128)
vqvae_trainer.compile(optimizer=tf.keras.optimizers.Adam())
vqvae_trainer.fit(train_ds, epochs=30, batch_size=128)

CodePudding user response:

This kind of model does not work with labels. Try running:

normalized_ds = dataset.map(lambda x, y: normalization_layer(x))

to discard the labels, since you are actually only interested in x.

  • Related