I have written some code to extract 5 bounding boxes/patches from a single image. When I run it and print the output shape, I get (5, 256): five patches, each with a 256-dimensional vector. The problem is that the patches are extracted separately per image, so when I feed 5000 images through this code it generates 5000*5 patches all mixed together, losing the patch/image relationship. I want to change this code to produce output with batch information, like (1, 5, 256), so that each batch entry represents one image.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def create_vision_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Frozen Xception backbone with global average pooling.
    xception = keras.applications.Xception(
        include_top=False, weights="imagenet", pooling="avg"
    )
    for layer in xception.layers:
        layer.trainable = trainable

    inputs = layers.Input(shape=(299, 299, 3), name="image_input")
    NUM_BOXES = 5
    CHANNELS = 3
    CROP_SIZE = (200, 200)
    # BATCH_SIZE is defined elsewhere in my script.
    boxes = tf.random.uniform(shape=(NUM_BOXES, 4))
    box_indices = tf.random.uniform(
        shape=(NUM_BOXES,), minval=0, maxval=BATCH_SIZE, dtype=tf.int32
    )
    output = tf.image.crop_and_resize(inputs, boxes, box_indices, CROP_SIZE)
    xception_input = tf.keras.applications.xception.preprocess_input(output)
    embeddings = xception(xception_input)
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    return keras.Model(inputs, outputs, name="vision_encoder")
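For reference, here is a minimal sketch of the shape behavior of tf.image.crop_and_resize that causes the mixing (a hypothetical batch of 3 images):

import tensorflow as tf

images = tf.random.normal((3, 299, 299, 3))      # batch of 3 images
boxes = tf.random.uniform((5, 4))                # 5 normalized boxes
box_indices = tf.zeros((5,), dtype=tf.int32)     # every box targets image 0

crops = tf.image.crop_and_resize(images, boxes, box_indices, (200, 200))
print(crops.shape)  # (5, 200, 200, 3) -- one flat crop axis, no batch axis

To get 5 boxes per image for all 3 images you would need 15 boxes and 15 box_indices, and the result would be a flat (15, 200, 200, 3) tensor in which the patch/image relationship is encoded only implicitly in box_indices.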
Answer:
You can create an ImagePatchesAndEmbedding layer that crops the bounding boxes from each image separately, stacks them, and applies xception to each image's stack of patches:
class ImagePatchesAndEmbedding(keras.layers.Layer):
    def __init__(self, crop_size, num_boxes=5):
        super(ImagePatchesAndEmbedding, self).__init__()
        self.crop_size = crop_size
        # Random boxes, sampled once at construction time.
        self.boxes = tf.random.uniform(shape=(num_boxes, 4))
        # Each image is cropped individually (a batch of 1),
        # so every box points at index 0.
        self.box_indices = tf.zeros(shape=(num_boxes,), dtype=tf.int32)
        self.preprocess = tf.keras.applications.xception.preprocess_input

    def call(self, inputs):
        # Crop per image so the (batch, num_boxes, ...) grouping is kept.
        patches = tf.map_fn(
            lambda img: tf.image.crop_and_resize(
                img[None, ...], self.boxes, self.box_indices, self.crop_size
            ),
            inputs,
        )
        # Embed each image's patch stack with the global `xception` model
        # built as in the question.
        embeddings = tf.map_fn(
            lambda patch: xception(self.preprocess(patch)), patches
        )
        return embeddings
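Note that tf.map_fn loops over the batch and crops each image with a batch of size 1, which is why box_indices is all zeros; the same randomly sampled boxes are reused for every image, matching the question's setup. If this loop is too slow, an alternative is to build all batch*num_boxes boxes at once, run xception on the flat crop tensor, and reshape back to (batch, num_boxes, 2048).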
Build the model (reusing the frozen xception backbone from the question as a global):

xception = keras.applications.Xception(
    include_top=False, weights="imagenet", pooling="avg"
)
xception.trainable = False

inputs = layers.Input(shape=(299, 299, 3), name="image_input")
NUM_BOXES = 5
CHANNELS = 3
CROP_SIZE = (200, 200)
BATCH_SIZE = 3

output = ImagePatchesAndEmbedding(CROP_SIZE, num_boxes=NUM_BOXES)(inputs)
model = keras.Model(inputs, output)
Call the model:

model(tf.random.normal(shape=(BATCH_SIZE, 299, 299, 3))).shape
# TensorShape([3, 5, 2048])

Each batch entry now keeps its own five patch embeddings.
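The 2048 here is the raw Xception feature size. To reach the (batch, 5, 256) shape asked for, the project_embeddings step from the question can be applied on top of this layer. A sketch, assuming project_embeddings applies Dense projections on the last axis (the num_projection_layers and dropout_rate values below are placeholders):

output = ImagePatchesAndEmbedding(CROP_SIZE, num_boxes=NUM_BOXES)(inputs)
outputs = project_embeddings(
    output, num_projection_layers=1, projection_dims=256, dropout_rate=0.1
)
model = keras.Model(inputs, outputs, name="vision_encoder")

model(tf.random.normal(shape=(BATCH_SIZE, 299, 299, 3))).shape
# TensorShape([3, 5, 256]) -- assuming the projection acts on the last axis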