How to modify image in custom Tensorflow layer? (working example provided)

Time:06-27

How can I draw a filled rectangle as a custom (data augmentation) layer in TensorFlow 2 with Python 3?

[Images: the input photo, and the expected output with a filled black rectangle drawn over it]

With image_pil = Image.fromarray(image), I get the error:

AttributeError: Exception encountered when calling layer "remove_patch_5" (type RemovePatch).

'tensorflow.python.framework.ops.EagerTensor' object has no attribute '__array_interface__'

Call arguments received by layer "remove_patch_5" (type RemovePatch):
  • image=tf.Tensor(shape=(300, 300, 3), dtype=uint8)
  • training=True
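As background for this error: Image.fromarray reads the input's __array_interface__, which a tf.Tensor does not provide, so the tensor has to be converted to NumPy first. A separate issue in the commented-out attempt further down is that ImageDraw.Draw(...).rectangle() draws in place and returns None, so the pixels have to be read back from the PIL image, not from that return value. A minimal sketch in eager mode, assuming a dummy all-white 300x300 tensor in place of the downloaded photo:

```python
import numpy as np
import tensorflow as tf
from PIL import Image, ImageDraw

# Dummy stand-in for the EagerTensor the layer receives (assumption: white 300x300 RGB)
image = tf.fill((300, 300, 3), tf.constant(255, dtype=tf.uint8))

# Image.fromarray needs a NumPy array; .numpy() is only available in eager mode
image_pil = Image.fromarray(image.numpy())

# rectangle() modifies image_pil in place and returns None,
# so read the result back from image_pil itself
ImageDraw.Draw(image_pil).rectangle([50, 50, 100, 100], fill="#000000")
result = tf.convert_to_tensor(np.array(image_pil))

print(result.shape, result.dtype)
```

Because .numpy() only exists on eager tensors, this works when the layer is called directly as in the example below, but not when Keras traces call() into a graph (for example during model.fit or inside tf.data.Dataset.map).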

Example

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from PIL import Image, ImageDraw


class RemovePatch(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, image, training=None):
        if not training:
            return image

        # This is the part that doesn't work: Image.fromarray() raises
        # the AttributeError shown above
        # image_pil = Image.fromarray(image)
        # image = np.array(
        #     ImageDraw.Draw(image_pil).rectangle(
        #         [50, 50, 100, 100], fill="#000000"
        #     )
        # )

        # This part works for adjusting brightness,
        # but no built-in function for drawing a
        # rectangle was found
        # image = tf.image.adjust_brightness(image, 0.5)

        return image


layer = RemovePatch()

image_file = "image.jpg"

try:
    open(image_file)
except FileNotFoundError:
    from requests import get

    r = get("https://picsum.photos/seed/picsum/300/300")
    with open(image_file, "wb") as f:
        f.write(r.content)

with Image.open(image_file) as img:
    img = np.array(img)
    augmented = layer(img, training=True)

    augmented = np.array(augmented)

    # plt.imshow(img)
    plt.imshow(augmented)
    plt.show()

show_expected = False
if show_expected:
    with Image.open(image_file) as img:
        ImageDraw.Draw(img).rectangle([50, 50, 100, 100], fill="#000000")

        plt.imshow(img)
        plt.show()
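If the goal is to keep the PIL drawing from the question, one option is to wrap it in tf.py_function, which runs a Python function eagerly even when the surrounding code is traced into a graph. This is a sketch under stated assumptions: RemovePatchPIL and _draw_patch_pil are hypothetical names, and it assumes a single unbatched (height, width, 3) uint8 image like the one in the question. Note that a py_function body runs in the Python interpreter and is not saved with the model.

```python
import numpy as np
import tensorflow as tf
from PIL import Image, ImageDraw


def _draw_patch_pil(image):
    # Executed eagerly by tf.py_function, so .numpy() is available here
    image_pil = Image.fromarray(image.numpy())
    # rectangle() modifies image_pil in place and returns None
    ImageDraw.Draw(image_pil).rectangle([50, 50, 100, 100], fill="#000000")
    return np.array(image_pil)


class RemovePatchPIL(tf.keras.layers.Layer):
    """Hypothetical variant of RemovePatch that delegates drawing to PIL.

    Assumes a single unbatched (height, width, 3) uint8 image.
    """

    def call(self, image, training=None):
        if not training:
            return image
        out = tf.py_function(_draw_patch_pil, [image], Tout=image.dtype)
        # py_function drops the static shape; restore it for later layers
        out.set_shape(image.shape)
        return out


layer = RemovePatchPIL()
img = tf.fill((300, 300, 3), tf.constant(255, dtype=tf.uint8))

# Eager call, as in the question
augmented = layer(img, training=True)

# Traced call: calling .numpy() directly in call() would fail here
traced = tf.function(lambda x: layer(x, training=True))
augmented_traced = traced(img)
```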

Answer:

A working solution with …

[image]
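Independently of the answer above, one graph-friendly way to draw a filled rectangle with only TensorFlow ops is to build a boolean mask from row and column indices and pick the fill value with tf.where. This is a sketch, not necessarily the answer's method: the box and fill_value constructor arguments are hypothetical additions, and the box is end-exclusive like Python slicing, so adjust it if you need PIL's exact pixel boundaries. Because it uses tf.shape and broadcasting, it should work for both a single (height, width, channels) image and a batched (batch, height, width, channels) input, and it needs no .numpy() call.

```python
import tensorflow as tf


class RemovePatch(tf.keras.layers.Layer):
    """Sketch: fill a rectangle with a constant using only TF ops."""

    def __init__(self, box=(50, 50, 100, 100), fill_value=0, **kwargs):
        super().__init__(**kwargs)
        # Hypothetical parameters: box = (x0, y0, x1, y1), end-exclusive
        self.box = box
        self.fill_value = fill_value

    def call(self, image, training=None):
        if not training:
            return image
        x0, y0, x1, y1 = self.box
        shape = tf.shape(image)
        height, width = shape[-3], shape[-2]
        rows = tf.range(height)[:, tf.newaxis]  # (H, 1)
        cols = tf.range(width)[tf.newaxis, :]   # (1, W)
        # (H, W) boolean mask that is True inside the rectangle
        inside = (rows >= y0) & (rows < y1) & (cols >= x0) & (cols < x1)
        # (H, W, 1) broadcasts over channels and, if present, the batch axis
        mask = inside[..., tf.newaxis]
        fill = tf.cast(self.fill_value, image.dtype)
        return tf.where(mask, fill, image)


layer = RemovePatch()
image = tf.fill((300, 300, 3), tf.constant(255, dtype=tf.uint8))

augmented = layer(image, training=True)
unchanged = layer(image, training=False)

# Batched input also works thanks to broadcasting
augmented_batch = layer(tf.stack([image, image]), training=True)
```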
