Stitching patches of Image together


Hi, I have a batch of images that I need to divide into non-overlapping patches, send each patch through the softmax function, and then reconstruct the original images. I can make the patches as follows:

import tensorflow as tf

@tf.function
def grid_img(img, patch_size=(256, 256), padding="VALID"):
    p_height, p_width = patch_size
    batch_size, height, width, n_filters = img.shape
    # (batch, n_rows, n_cols, p_height * p_width * n_filters)
    p = tf.image.extract_patches(images=img,
                                 sizes=[1, p_height, p_width, 1],
                                 strides=[1, p_height, p_width, 1],
                                 rates=[1, 1, 1, 1],
                                 padding=padding)
    # Unflatten each patch: (batch, n_rows, n_cols, p_height, p_width, n_filters)
    new_shape = list(p.shape[1:-1]) + [p_height, p_width, n_filters]
    p = tf.keras.layers.Reshape(new_shape)(p)
    return p

But I can't figure out how to reconstruct the original images from the batches of patches. Simply reshaping back to the original shape doesn't work; the data would not be arranged in the right way. I would appreciate any help. Thanks.
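For example, a quick check with a toy 4×4 single-channel image and 2×2 patches (sizes chosen purely for illustration) shows how a plain reshape scrambles the pixel order:

import tensorflow as tf

img = tf.reshape(tf.range(16), (1, 4, 4, 1))
patches = tf.image.extract_patches(img,
                                   sizes=[1, 2, 2, 1],
                                   strides=[1, 2, 2, 1],
                                   rates=[1, 1, 1, 1],
                                   padding='VALID')
naive = tf.reshape(patches, (1, 4, 4, 1))
print(tf.reduce_all(naive == img).numpy())  # False: the patch rows end up interleaved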

CodePudding user response:

IIUC, you should be able to simply use tf.reshape to reconstruct the original images from batches of patches:

import tensorflow as tf

samples = 5
images = tf.random.normal((samples, 256, 256, 3))

@tf.function
def grid(images):
  img_shape = tf.shape(images)
  batch_size, height, width, n_filters = img_shape[0], img_shape[1], img_shape[2], img_shape[3]

  # (batch, n_rows, n_cols, 32 * 32 * n_filters)
  patches = tf.image.extract_patches(images=images,
                                     sizes=[1, 32, 32, 1],
                                     strides=[1, 32, 32, 1],
                                     rates=[1, 1, 1, 1],
                                     padding='VALID')
  # Softmax over each flattened patch, then reshape back to the image shape
  # (this restores the shape, though not the original pixel ordering; see the update below).
  return tf.reshape(tf.nn.softmax(patches), (batch_size, height, width, n_filters))

patches = grid(images)
print(patches.shape)
# (5, 256, 256, 3)

Update 1: If you want to reconstruct the images in the correct order, you can calculate the gradients of tf.image.extract_patches, as shown in the sketch below.
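A minimal sketch of that gradient trick (the helper name reconstruct_via_gradient and the 32×32 patch size are illustrative, not from the original post): the vector-Jacobian product of tf.image.extract_patches scatters each patch entry back to the pixel it was read from, and since non-overlapping patches read every pixel exactly once, this gives an exact inverse.

import tensorflow as tf

def reconstruct_via_gradient(patches, image_shape, patch_size=(32, 32)):
    p_h, p_w = patch_size
    # A zero tensor just to trace extract_patches; only its shape matters.
    dummy = tf.zeros(image_shape)
    with tf.GradientTape() as tape:
        tape.watch(dummy)
        extracted = tf.image.extract_patches(images=dummy,
                                             sizes=[1, p_h, p_w, 1],
                                             strides=[1, p_h, p_w, 1],
                                             rates=[1, 1, 1, 1],
                                             padding='VALID')
    # The VJP scatters each patch entry back to its source pixel; with
    # non-overlapping patches every pixel is written exactly once.
    return tape.gradient(extracted, dummy, output_gradients=patches)

images = tf.random.normal((5, 256, 256, 3))
patches = tf.image.extract_patches(images,
                                   sizes=[1, 32, 32, 1],
                                   strides=[1, 32, 32, 1],
                                   rates=[1, 1, 1, 1],
                                   padding='VALID')
reconstructed = reconstruct_via_gradient(patches, images.shape)
print(tf.reduce_max(tf.abs(reconstructed - images)).numpy())  # 0.0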

CodePudding user response:

A dirty workaround I thought of is to track the location of the pixels after the transformation. Not as elegant as @alonetogether's answer, but it might still be helpful to share.

import numpy as np
import tensorflow as tf

@tf.function
def grid(images, grid_size=(32, 32)):
    grid_height, grid_width = grid_size
    patches = tf.image.extract_patches(images=images,
                                       sizes=[1, grid_height, grid_width, 1],
                                       strides=[1, grid_height, grid_width, 1],
                                       rates=[1, 1, 1, 1],
                                       padding='VALID')
    return patches

batch_size, height, width, n_filters = shape = (5, 256, 256, 1)

# Run an image of flat indices through the transformation once to record
# where every pixel ends up.
indices = tf.range(batch_size * height * width * n_filters)
images = tf.reshape(indices, (batch_size, height, width, n_filters))

patches = grid(images)
transferred_indices = tf.reshape(patches, shape=[-1])
tracked_indices = tf.argsort(transferred_indices)  # indices after transformation; save this

# Any batch of the same shape can now be reconstructed with one gather.
images = tf.random.normal(shape)
patches = grid(images)
flatten_patches = tf.reshape(patches, shape=[-1])
reconstructed = tf.reshape(tf.gather(flatten_patches, tracked_indices), shape)

np.all(reconstructed == images)  # True
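
Note that tracked_indices only has to be computed once per image shape and can then be reused for every batch, so the per-batch cost is a single tf.gather. The softmax from the question can also be applied before the gather, since it changes values in place without moving them across positions; a small sketch reusing the names above:

soft = tf.nn.softmax(grid(images))  # softmax within each patch (over the last axis)
flat = tf.reshape(soft, shape=[-1])
result = tf.reshape(tf.gather(flat, tracked_indices), shape)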