How to map colors in an image to a palette fast with numpy?

Time:10-25

I have two arrays. One is an image array and the other is a palette array. Both have elements containing 8-bit RGB channels. I need to replace every color in the image with the closest color in the palette.

Currently I'm measuring distance in RGB space, which is not ideal but easy to implement.

This is my implementation:

import numpy as np

image_array = np.array(image)  # converts PIL image, values are uint8
# palette values are also 8-bit but I use int so I don't have to cast types
palette_array = np.array(palette, dtype=[('red', int), ('green', int), ('blue', int)])
mapped_image = np.empty((image_height, image_width, 3), dtype=np.uint8)
for x in range(image_width):
    for y in range(image_height):
        r, g, b = image_array[y, x]
        # squared Euclidean distance from this pixel to every palette color
        distances_squared = (r-palette_array['red'])**2 + (g-palette_array['green'])**2 + (b-palette_array['blue'])**2
        closest_index = np.argmin(distances_squared)
        closest_color = palette_array.flat[closest_index]
        mapped_image[y, x] = closest_color

The palette has 4096 random colors (simple conversion is not possible). Mapping a 600x448 image takes roughly a minute even on my Core i5 machine. I plan to use this on lower-end devices like a Raspberry Pi, where it takes roughly 3 minutes to map a small image.

This is way too slow. I believe this can be sped up significantly when the full loop is implemented with numpy syntax, but I can't figure out how to do this.

How do I get from the original image to the mapped one all implemented with numpy syntax?

CodePudding user response:

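You can try the cKDTree class from scipy.spatial. It builds a k-d tree over the palette once and then looks up the nearest palette color for all pixels in a single query.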

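import numpy as np
from scipy.spatial import cKDTree

# high is exclusive, so 256 covers the full 8-bit range 0..255
palette = np.random.randint(0, 256, size=(4096, 3), dtype=np.uint8)  # random palette
image_in = np.random.randint(0, 256, size=(800, 600, 3), dtype=np.uint8)  # random image
size = image_in.shape
vor = cKDTree(palette)  # build the k-d tree over the palette colors once
test_points = np.reshape(image_in, (-1, 3))  # flatten to one row per pixel
# index of the nearest palette color (Euclidean distance in RGB) for every pixel
_, test_point_regions = vor.query(test_points, k=1)
image_out = palette[test_point_regions]
image_out = np.reshape(image_out, size)  # back to the original image shape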

This program runs for approximately 0.8 seconds.
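
If scipy is not available on the target device, the same nearest-color lookup can also be written with numpy alone. The following is only a sketch under some assumptions: the function name map_to_palette and the chunk_size parameter are made up for illustration, and both arrays are assumed to be uint8 with a trailing axis of length 3. It uses the expansion |p - c|^2 = |p|^2 - 2 p.c + |c|^2 and drops |p|^2, which is constant for a given pixel and so does not change which palette color is nearest. Pixels are processed in chunks so the distance matrix stays small.

```python
import numpy as np

def map_to_palette(image, palette, chunk_size=1024):
    """Replace each pixel with its nearest palette color (squared RGB distance).

    Hypothetical helper: `image` is (H, W, 3) uint8, `palette` is (P, 3) uint8.
    """
    # float32 lets the matrix product use the fast floating-point path;
    # all intermediate values here are integers well below 2**24,
    # so they are represented exactly in float32.
    pixels = image.reshape(-1, 3).astype(np.float32)   # (N, 3)
    pal = palette.astype(np.float32)                   # (P, 3)
    pal_sq = (pal ** 2).sum(axis=1)                    # (P,) squared norms |c|^2

    indices = np.empty(len(pixels), dtype=np.intp)
    for start in range(0, len(pixels), chunk_size):
        chunk = pixels[start:start + chunk_size]       # (n, 3)
        # |c|^2 - 2 p.c for every pixel/palette pair; |p|^2 is omitted
        # because it does not affect the argmin over palette colors.
        dist = pal_sq[None, :] - 2.0 * (chunk @ pal.T)  # (n, P)
        indices[start:start + chunk_size] = dist.argmin(axis=1)

    # Look up the chosen colors and restore the original image shape.
    return palette[indices].reshape(image.shape)

# Example with random data, as in the answer above
palette = np.random.randint(0, 256, size=(4096, 3), dtype=np.uint8)
image_in = np.random.randint(0, 256, size=(448, 600, 3), dtype=np.uint8)
image_out = map_to_palette(image_in, palette)
```

With chunk_size=1024 and 4096 palette colors, each temporary distance matrix holds 1024 x 4096 float32 values (about 16 MB); a smaller chunk_size trades some speed for less memory. When two palette colors are equally close, argmin picks the one with the lower index, so the result may differ from the cKDTree version in such ties.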
