Matplotlib plot becomes blank after tf.image.resize

Time:10-02

I have some code that I'm using with TensorFlow datasets. It has worked fine before, and it may still work, but I don't think so:

import matplotlib.pyplot as plt
import tensorflow as tf

img = parse_image(img_paths[0])
img = tf.image.resize(img, [224, 224])
plt.imshow(img)

This just outputs a blank 224x224 canvas, while

img = parse_image(img_paths[0])
plt.imshow(img)

displays the image correctly.

img_paths is a list of path strings.

I have tried:

img = parse_image(img_paths[0])
img = tf.image.resize([img], [224, 224])
plt.imshow(img[0])

and

img = parse_image(img_paths[0])
img = tf.image.resize(img, [224, 224])
plt.imshow(img.numpy())

and

img = parse_image(img_paths[0])
img = tf.image.resize([img], [224, 224])
plt.imshow(img.numpy()[0])

The shape is correct, and this code has worked before. It may still work; I suspect I'm just not using it correctly anymore (it's been a while since I wrote it).

Thanks for any hints or thoughts you can provide, and of course solutions ;-)

Answer:

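The "problem" is with Matplotlib. tf.image.resize returns float32 (unless you use the nearest-neighbor method, which keeps the input dtype), but it does not rescale the values: a uint8 image in the 0-255 range comes back as floats in the 0-255 range. For RGB(A) data, plt.imshow() accepts two formats: integers in 0-255 and floats in 0-1. Values outside that range are clipped (Matplotlib warns that it is clipping input data to the valid range), so every float above 1 becomes 1 and you see a white image. I suspect this is what you're getting.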
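A minimal sketch of the dtype change described above, assuming TensorFlow 2.x in eager mode. The question's parse_image() isn't shown, so a synthetic uint8 RGB image stands in for its output (hypothetical). The sketch checks what tf.image.resize returns and shows two ways to get back to a range imshow accepts.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for a decoded image: uint8 RGB values in 0-255.
rng = np.random.default_rng(0)
img = tf.constant(rng.integers(0, 256, size=(64, 48, 3)).astype(np.uint8))

# Default method is bilinear: output dtype is float32, but values stay on
# the 0-255 scale, so imshow would clip everything above 1.0 (mostly white).
resized = tf.image.resize(img, [224, 224])

# Option 1: rescale the floats into the 0-1 range imshow expects for floats.
as_unit_float = resized / 255.0

# Option 2: go back to uint8 (0-255); saturate_cast clips before casting.
as_uint8 = tf.saturate_cast(tf.round(resized), tf.uint8)
```

Either plt.imshow(as_unit_float) or plt.imshow(as_uint8) should then display the image instead of a white canvas.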

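tf.image.convert_image_dtype rescales when it converts an integer image to a float dtype: uint8 values in 0-255 become float32 values in 0-1 (it divides by the dtype's maximum, 255). Its saturate argument (False by default) only controls clipping before a cast to a narrower type; it is not what does the rescaling. This is why it "works": Matplotlib understands that format. After this, the default bilinear resize keeps the values between 0 and 1 too, so the plot displays correctly.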
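A small sketch of that conversion, assuming TensorFlow 2.x in eager mode and a hypothetical one-pixel uint8 image. It also shows a caveat worth checking against parse_image: if the input is already float32, convert_image_dtype returns it unchanged, so floats on a 0-255 scale would stay there and the plot would still be white.

```python
import numpy as np
import tensorflow as tf

# Hypothetical one-pixel uint8 image, shape (1, 1, 3), values 0-255.
img_uint8 = tf.constant(np.array([[[0, 128, 255]]], dtype=np.uint8))

# Integer -> float conversion scales by 1/255 (the uint8 max): 255 -> ~1.0.
img_float = tf.image.convert_image_dtype(img_uint8, tf.float32)

# Bilinear resize of already-scaled data keeps values in [0, 1].
resized = tf.image.resize(img_float, [224, 224])

# Caveat: same-dtype conversion is a no-op, so 0-255 floats are NOT rescaled.
already_float = tf.constant([[[0.0, 128.0, 255.0]]])
unchanged = tf.image.convert_image_dtype(already_float, tf.float32)
```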

Answer:

Huh,

I saw something elsewhere and added this line:

img = tf.image.convert_image_dtype(img, tf.float32)

before resizing and it worked.

This is strange, because I didn't need this line before. Maybe it's due to a version update?

Either way this works:

img = parse_image(train_img_paths[0])
img = tf.image.convert_image_dtype(img, tf.float32)
img = tf.image.resize(img, [224, 224])
plt.imshow(img)
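If parse_image might return either integers or floats, a helper along these lines could make the display step robust. This is a hypothetical sketch, not the asker's code: to_displayable is an invented name, and treating any float image whose maximum exceeds 1 as 0-255 data is a heuristic assumption. It relies on Python-side checks, so it is meant for eager, interactive use rather than inside Dataset.map.

```python
import tensorflow as tf

def to_displayable(img, size=(224, 224)):
    """Hypothetical helper: resize an image and return float32 values in [0, 1].

    Assumes integer images use their dtype's full range (0-255 for uint8)
    and float images are either already in [0, 1] or on a 0-255 scale.
    """
    img = tf.convert_to_tensor(img)
    if img.dtype.is_integer:
        # convert_image_dtype divides by the dtype's max (255 for uint8).
        img = tf.image.convert_image_dtype(img, tf.float32)
    else:
        img = tf.cast(img, tf.float32)
        # Heuristic (assumption): floats above 1 are treated as 0-255 data.
        if float(tf.reduce_max(img)) > 1.0:
            img = img / 255.0
    img = tf.image.resize(img, size)  # bilinear by default
    # Guard against tiny rounding overshoots before passing to imshow.
    return tf.clip_by_value(img, 0.0, 1.0)
```

Usage would look like plt.imshow(to_displayable(parse_image(train_img_paths[0]))).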