I need to apply tf.image.crop_and_resize to my images and generate 5 boxes from each image. I have written the code below, which works fine:
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np

# Load the pre-trained Xception model to be used as the base encoder.
xception = keras.applications.Xception(
    include_top=False, weights="imagenet", pooling="avg"
)
# Set the trainability of the base encoder.
for layer in xception.layers:
    layer.trainable = False
# Receive the images as inputs.
inputs = layers.Input(shape=(299, 299, 3), name="image_input")

# Load one image (target_size is (height, width)) and add a batch axis.
image_path = '/content/1.png'
img = tf.keras.preprocessing.image.load_img(image_path, target_size=(299, 299))
image = tf.expand_dims(np.asarray(img) / 255, axis=0)

BATCH_SIZE = 1
NUM_BOXES = 5
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
CHANNELS = 3
CROP_SIZE = (24, 24)

# Each box is [y1, x1, y2, x2] in normalized coordinates.
boxes = tf.random.uniform(shape=(NUM_BOXES, 4))
box_indices = tf.random.uniform(shape=(NUM_BOXES,), minval=0, maxval=BATCH_SIZE, dtype=tf.int32)
output = tf.image.crop_and_resize(image, boxes, box_indices, CROP_SIZE)
xception_input = tf.keras.applications.xception.preprocess_input(output)
```
The code above works fine. However, when I want to display these boxes, I run the code below:
```python
for i in range(5):
    # define subplot
    plt.subplot(330 + 1 + i)
    # generate batch of images
    batch = xception_input.next()
    # convert to unsigned integers for viewing
    image = batch[0].astype('uint8')
    image = np.reshape(24,24,3)
    # plot raw pixel data
    plt.imshow(image)
# show the figure
plt.show()
```
But it generates this error: `AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object has no attribute 'next'`
Answer:
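You have to use `[i]` instead of `.next()`: an `EagerTensor` is not an iterator, so it has no `.next()` method, but it can be indexed along its first axis. There is also a problem with converting it to `uint8` (but there is no need to reshape it):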
```python
for i in range(5):
    plt.subplot(331 + i)
    tensor = xception_input[i]
    #print(tensor)
    tensor = tensor*255
    image = np.array(tensor, dtype=np.uint8)
    #print(image)
    plt.imshow(image)
```
Or use a `for` loop to get the items:
```python
for i, tensor in enumerate(xception_input):
    #print(tensor)
    plt.subplot(331 + i)
    tensor = tensor*255
    image = np.array(tensor, dtype=np.uint8)
    #print(image)
    plt.imshow(image)
```
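To compare the access patterns side by side, here is a minimal sketch that uses a NumPy array as a hypothetical stand-in for `xception_input` (5 crops of shape 24x24x3, matching `NUM_BOXES` and `CROP_SIZE`). Like an eager tensor, an array can be indexed and iterated along its first axis, but it is not an iterator, so it has no `.next()` method; Python's built-in `next()` only works after wrapping it with `iter()`.

```python
import numpy as np

# Hypothetical stand-in for xception_input: 5 crops of 24x24 RGB pixels,
# where crop i is filled with the value i so the crops are easy to tell apart.
crops = np.stack([np.full((24, 24, 3), i, dtype=np.float32) for i in range(5)])

# 1. Indexing selects one crop along the first axis.
third = crops[2]

# 2. A for loop walks the first axis, yielding one (24, 24, 3) crop per step.
first_pixels = [float(crop[0, 0, 0]) for crop in crops]

# 3. The array is iterable but is not an iterator itself, so it has no
#    .next() method; the built-in next() needs an explicit iter() first.
has_next_method = hasattr(crops, "next")
it = iter(crops)
first = next(it)
second = next(it)
```

With the real `xception_input`, the same three patterns should apply: `xception_input[i]`, `for crop in xception_input:`, or `next(iter(xception_input))`.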
I don't know what your code is supposed to do, but this gives me empty images because `tensor` has values close to `-1` (like `-0.99`), and the conversion to `uint8` turns them all into `0`.
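The values near -1 come from the preprocessing step rather than from `crop_and_resize`. Xception's `preprocess_input` scales pixels from [0, 255] to [-1, 1] (as `x / 127.5 - 1`). Because the question's image was already divided by 255, every crop ends up between -1 and roughly -0.992, and multiplying by 255 does not undo that. The sketch below reproduces that scaling in NumPy and shows how to map [-1, 1] values back to displayable `uint8` pixels. The function names are hypothetical, and the scaling formula is assumed to match what Keras applies for Xception.

```python
import numpy as np

def xception_style_preprocess(x):
    # Assumed equivalent of keras.applications.xception.preprocess_input,
    # which maps pixel values from [0, 255] to [-1, 1].
    return np.asarray(x, dtype=np.float32) / 127.5 - 1.0

def to_displayable(x):
    """Hypothetical helper: undo the [-1, 1] scaling and return uint8 pixels."""
    pixels = (np.asarray(x, dtype=np.float32) + 1.0) * 127.5
    # Round and clip before casting so values cannot wrap around in uint8.
    return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)

# Pixels already divided by 255 (as in the question) all land near -1 ...
already_scaled = xception_style_preprocess([0.0, 0.5, 1.0])
# ... while pixels still in [0, 255] use the whole [-1, 1] range.
full_range = xception_style_preprocess([0.0, 127.5, 255.0])
restored = to_displayable(full_range)
```

In the question's pipeline, this round trip only recovers a visible image if the `/ 255` is dropped before `crop_and_resize`, so that `preprocess_input` receives [0, 255] pixels. To simply look at the crops, `plt.imshow(output[i].numpy())` should also work, since `imshow` accepts float RGB values in [0, 1].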