Convert an image from RGB to index in palette using Tensorflow

Time:04-22

I want to convert an RGB image to one with a single channel, whose value is an integer index from a palette (which has already been extracted).

An example:

import tensorflow as tf

# image shape (height=2, width=2, channels=3)
image = tf.constant([
  [
    [1., 1., 1.], [1., 0., 0.]
  ],
  [
    [0., 0., 1.], [1., 0., 0.]
  ]
])

# palette is a tensor with the extracted colors
# palette shape (num_colors_in_palette, 3) 
palette = tf.constant([
  [1., 0., 0.],
  [0., 0., 1.],
  [1., 1., 1.]
])

indexed_image = rgb_to_indexed(image, palette)
# desired result: [[2, 0], [1, 0]]
# result shape (height, width)

I can imagine a few ways to implement rgb_to_indexed(image, palette) in pure Python, but I'm having trouble working out how to implement it the TensorFlow way (using @tf.function for AutoGraph and avoiding for loops), relying only (or mostly) on vectorized operations.

Edit 1: sample Python/NumPy code

If the code need not use Tensorflow, a non-vectorized implementation could be:

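import numpy as np

def rgb_to_indexed(image, palette):
    # one integer palette index per pixel, shape (height, width)
    result = np.empty(shape=image.shape[:2], dtype=np.int64)

    for i, row in enumerate(image):
        for j, color in enumerate(row):
            # positions of the palette rows that are equal to this pixel's color
            index, = np.where(np.all(palette == color, axis=1))
            result[i, j] = index[0]  # assumes the color is present in the palette
    return result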

indexed_image = rgb_to_indexed(image.numpy(), palette.numpy())
# indexed_image is [[2, 0], [1, 0]]
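For comparison, a loop-based variant can replace the per-pixel np.where search over the whole palette with a dictionary lookup keyed by color tuples, so each pixel costs one hash lookup instead of a comparison against every palette color. This is a minimal sketch, not part of the original question; the name rgb_to_indexed_dict and the not_found default are illustrative assumptions:

```python
import numpy as np

def rgb_to_indexed_dict(image, palette, not_found=-1):
    # Hypothetical helper: map each palette color (as a hashable tuple) to its row index.
    # If a color appears twice in the palette, the later row wins.
    lookup = {tuple(color): k for k, color in enumerate(palette.tolist())}
    height, width = image.shape[:2]
    result = np.empty((height, width), dtype=np.int64)
    for i in range(height):
        for j in range(width):
            # colors missing from the palette get the not_found sentinel
            result[i, j] = lookup.get(tuple(image[i, j].tolist()), not_found)
    return result

image = np.array([[[1., 1., 1.], [1., 0., 0.]],
                  [[0., 0., 1.], [1., 0., 0.]]])
palette = np.array([[1., 0., 0.], [0., 0., 1.], [1., 1., 1.]])
print(rgb_to_indexed_dict(image, palette).tolist())  # [[2, 0], [1, 0]]
```

It is still a Python loop, so it will not trace into a @tf.function graph; it is mainly useful as a readable reference to check vectorized versions against.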

Answer:

I used the technique described in another question (Find indices of rows of numpy 2d array in another 2D array) and adapted it from NumPy to TensorFlow. It is fully vectorized (only broadcasting, reductions and indexing, no Python loops), which is what makes it fast.
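To make the mask step concrete, the following sketch runs the broadcasting on the question's 2x2 example (the variable names flat, mask and full_match are only for illustration). It uses mask.all(axis=2), which gives the same result as comparing sum(axis=2) with the number of channels:

```python
import numpy as np

image = np.array([[[1., 1., 1.], [1., 0., 0.]],
                  [[0., 0., 1.], [1., 0., 0.]]])
palette = np.array([[1., 0., 0.], [0., 0., 1.], [1., 1., 1.]])

flat = image.reshape(-1, 3)        # (num_pixels=4, channels=3)
mask = flat == palette[:, None]    # (3, 1, 3) vs (4, 3) -> (num_colors=3, num_pixels=4, 3)
full_match = mask.all(axis=2)      # (3, 4): True where every channel matches
color_idx, pixel_idx = np.where(full_match)  # coordinates of the True entries, row-major
print(mask.shape)                  # (3, 4, 3)
print(color_idx, pixel_idx)        # [0 0 1 2] [1 3 2 0]
```

Writing color_idx into positions pixel_idx of a length-4 array gives [2, 0, 1, 0], which reshapes to [[2, 0], [1, 0]].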

First in Numpy (vectorized):

import numpy as np

def rgb_to_indexed(image, palette):
    original_shape = image.shape

    # flattens the image to the shape (height*width, channels)
    flattened_image = image.reshape(original_shape[0] * original_shape[1], -1)
    num_pixels, num_channels = flattened_image.shape

    # compares every pixel with every palette color: palette[:, None] has shape
    # (num_colors, 1, channels), so the mask has shape (num_colors, num_pixels, channels)
    indices = flattened_image == palette[:, None]
    # a pixel matches a color only when all of its channels match
    row_sums = indices.sum(axis=2)
    # reduces the mask to two lists of indices:
    # a) color in the palette, and b) pixel in the image
    color_indices, pixel_indices = np.where(row_sums == num_channels)

    # sets -42 as the default value in case some color is not in the palette,
    # then replaces the values for which some index has been found in the palette
    INDEX_OF_COLOR_NOT_FOUND = -42
    indexed_image = np.full(num_pixels, INDEX_OF_COLOR_NOT_FOUND, dtype="int64")
    indexed_image[pixel_indices] = color_indices

    # reshapes to "deflatten" the indexed_image and give it a single channel (the index)
    indexed_image = indexed_image.reshape(original_shape[:2])

    return indexed_image
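The mask above has num_colors x num_pixels x channels elements, which can get large for big images and palettes. If the channels are 8-bit integers (an assumption; the question's example uses floats in [0, 1]), a lower-memory alternative is to pack each RGB triple into one integer key and look the keys up with np.searchsorted in the sorted, packed palette. This is a different technique from the answer's, sketched with a hypothetical helper name:

```python
import numpy as np

def rgb_to_indexed_packed(image, palette, not_found=-42):
    # Hypothetical helper. Assumes integer channel values in 0..255 and a non-empty palette.
    def pack(rgb):
        # combines R, G, B into one key: R * 2**16 + G * 2**8 + B
        rgb = rgb.astype(np.int64)
        return (rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]

    keys = pack(palette)                          # (num_colors,)
    order = np.argsort(keys, kind="stable")       # sort palette keys for binary search
    sorted_keys = keys[order]
    pixel_keys = pack(image)                      # (height, width)
    pos = np.searchsorted(sorted_keys, pixel_keys)    # insertion points in 0..num_colors
    pos = np.clip(pos, 0, len(sorted_keys) - 1)       # keep positions valid for indexing
    found = sorted_keys[pos] == pixel_keys            # exact match at the insertion point?
    return np.where(found, order[pos], not_found)     # map back to original palette rows

image = np.array([[[255, 255, 255], [255, 0, 0]],
                  [[0, 0, 255], [10, 20, 30]]], dtype=np.uint8)
palette = np.array([[255, 0, 0], [0, 0, 255], [255, 255, 255]], dtype=np.uint8)
print(rgb_to_indexed_packed(image, palette).tolist())  # [[2, 0], [1, -42]]
```

The stable argsort makes duplicate palette colors resolve to their lowest index, and memory stays proportional to the number of pixels rather than pixels times palette colors.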

Then my translation to Tensorflow:

import tensorflow as tf

@tf.function
def rgb_to_indexed(image, palette):
    original_shape = tf.shape(image)

    # flattens the image to have (height*width, channels)
    # so it has the same rank as the palette
    flattened_image = tf.reshape(image, [original_shape[0] * original_shape[1], -1])
    num_pixels, num_channels = tf.shape(flattened_image)[0], tf.shape(flattened_image)[1]

    # does the mask magic but using tensorflow ops:
    # (num_colors, 1, channels) vs (num_pixels, channels) -> (num_colors, num_pixels, channels)
    indices = tf.equal(flattened_image, palette[:, None])
    row_sums = tf.reduce_sum(tf.cast(indices, "int32"), axis=2)
    # tf.where returns int64 (color, pixel) coordinates of the full matches
    results = tf.cast(tf.where(row_sums == num_channels), "int32")

    color_indices, pixel_indices = results[:, 0], results[:, 1]
    # scatter indices into a 1-D tensor need shape (num_updates, 1)
    pixel_indices = tf.expand_dims(pixel_indices, -1)

    # fills with the default value, then writes the palette color index
    # of every pixel whose color is present in the palette
    INDEX_OF_COLOR_NOT_FOUND = -42
    indexed_image = tf.fill([num_pixels], INDEX_OF_COLOR_NOT_FOUND)
    indexed_image = tf.tensor_scatter_nd_update(indexed_image, pixel_indices, color_indices)

    # reshapes the image back to (height, width)
    indexed_image = tf.reshape(indexed_image, [original_shape[0], original_shape[1]])

    return indexed_image
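For comparison, the same lookup can be written without tf.where coordinates or a scatter, by comparing each pixel with each palette color in a (height, width, num_colors, channels) tensor and taking an argmax over the palette axis. A minimal sketch, assuming TensorFlow 2.x and exact color matches; the name rgb_to_indexed_tf_argmax and the not_found parameter are hypothetical:

```python
import tensorflow as tf

@tf.function
def rgb_to_indexed_tf_argmax(image, palette, not_found=-42):
    # (H, W, 1, 3) vs (1, 1, K, 3) -> (H, W, K, 3); all channels equal -> (H, W, K)
    matches = tf.reduce_all(
        tf.equal(image[:, :, tf.newaxis, :], palette[tf.newaxis, tf.newaxis, :, :]),
        axis=-1)
    # index of a matching palette color (meaningless where nothing matches; fixed below)
    indices = tf.argmax(tf.cast(matches, tf.int32), axis=-1, output_type=tf.int32)
    # pixels with no matching palette color get the sentinel value
    return tf.where(tf.reduce_any(matches, axis=-1), indices,
                    tf.fill(tf.shape(indices), not_found))

image = tf.constant([[[1., 1., 1.], [1., 0., 0.]],
                     [[0., 0., 1.], [1., 0., 0.]]])
palette = tf.constant([[1., 0., 0.], [0., 0., 1.], [1., 1., 1.]])
print(rgb_to_indexed_tf_argmax(image, palette).numpy().tolist())  # [[2, 0], [1, 0]]
```

The trade-off is memory: the intermediate has height x width x num_colors x channels elements, the same order as the mask in the answer, so it suits small palettes best.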