I have an array of RGBA values that looks something like this:
# Not all elements are [0, 0, 0, 0]
array([[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
...,
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]])
I also have a function that returns which of five named colours (green, red, orange, brown, white) a given RGBA value is closest to:
import webcolors

def closest_colour(requested_colour):
    min_colours = {}
    for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
        if name in ['green', 'red', 'orange', 'brown', 'white']:
            r_c, g_c, b_c = webcolors.hex_to_rgb(key)
            rd = (r_c - requested_colour[0]) ** 2
            gd = (g_c - requested_colour[1]) ** 2
            bd = (b_c - requested_colour[2]) ** 2
            min_colours[rd + gd + bd] = name
    return min_colours[min(min_colours.keys())]
I'd like to apply this function to each pixel of my NumPy array and replace each pixel accordingly. I tried it this way:
img_array[closest_colour(img_array) == 'green'] = (0, 255, 0, 1)
img_array[closest_colour(img_array) == 'red'] = (255, 0, 0, 1)
img_array[closest_colour(img_array) == 'brown'] = (92, 64, 51, 1)
img_array[closest_colour(img_array) == 'orange'] = (255, 165, 0, 1)
img_array[closest_colour(img_array) == 'white'] = (255, 255, 255, 0)
but I get an error:
TypeError: unhashable type: 'numpy.ndarray'
I understand why this error occurs, but I don't know another way to do this efficiently.
Is there an efficient way to do this? I'm working with a fairly large array (an image).
Answer:
You can use numpy.apply_along_axis:
np.apply_along_axis(closest_colour, axis=1, arr=img_array)
If you want to replace the pixels rather than just classify them, have your function return the new RGBA values instead of a colour name.
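A minimal sketch of that idea, assuming a hard-coded palette of CSS3 RGB values in place of the webcolors lookup so it runs without that dependency. PALETTE, REPLACEMENT and closest_replacement are names made up for this example; the replacement values are the ones from the question. Passing axis=-1 instead of axis=1 lets the same call handle both an (N, 4) array and an (H, W, 4) image.

```python
import numpy as np

# CSS3 RGB values of the five named colours (hard-coded for this sketch)
PALETTE = {
    'green': (0, 128, 0),
    'red': (255, 0, 0),
    'orange': (255, 165, 0),
    'brown': (165, 42, 42),
    'white': (255, 255, 255),
}

# RGBA value each matched colour should become (values from the question)
REPLACEMENT = {
    'green': (0, 255, 0, 1),
    'red': (255, 0, 0, 1),
    'brown': (92, 64, 51, 1),
    'orange': (255, 165, 0, 1),
    'white': (255, 255, 255, 0),
}

def closest_replacement(pixel):
    """Return the replacement RGBA for the palette colour nearest to pixel's RGB."""
    # Convert to Python ints so uint8 input cannot wrap around when subtracting
    rgb = [int(c) for c in pixel[:3]]
    name = min(PALETTE, key=lambda n: sum((p - c) ** 2 for p, c in zip(PALETTE[n], rgb)))
    return np.array(REPLACEMENT[name])

img_array = np.array([[0, 0, 0, 0], [250, 10, 5, 255], [10, 120, 20, 255]])
# axis=-1 passes each RGBA vector (the last axis) to the function
result = np.apply_along_axis(closest_replacement, -1, img_array)
```

Keep in mind that apply_along_axis still calls the Python function once per pixel, so it is convenient but will not be fast on a large image.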
Answer:
I would rewrite your function to be a bit more vectorized. First, you really don't need to loop through the entire dictionary of CSS colors for every pixel: the lookup table can be trivially precomputed. Second, you can map the five colors you want to RGBA values without using the names as an intermediary. This will make your life much easier since you'll be working with numbers instead of strings most of the time.
import numpy as np
import webcolors

names = dict.fromkeys(['green', 'red', 'orange', 'brown', 'white'])
for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
    if name in names:
        names[name] = key
# One RGBA row per colour, with alpha fixed at 1
lookup = np.array([webcolors.hex_to_rgb(key) + (1,) for key in names.values()])
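If webcolors is not installed, or the CSS3_HEX_TO_NAMES mapping is not available in the release you have, the same five-row table can be built directly from the CSS3 hex codes. HEX and hex_to_rgba are hypothetical names for this sketch; the hex values are the standard CSS definitions of these five colour names.

```python
import numpy as np

# CSS3 hex codes for the five colours of interest, in lookup-table row order
HEX = {
    'green': '#008000',
    'red': '#ff0000',
    'orange': '#ffa500',
    'brown': '#a52a2a',
    'white': '#ffffff',
}

def hex_to_rgba(code, alpha=1):
    """Parse '#rrggbb' into an (r, g, b, alpha) tuple of ints."""
    r, g, b = bytes.fromhex(code.lstrip('#'))
    return (r, g, b, alpha)

names = list(HEX)  # row i of lookup corresponds to names[i]
lookup = np.array([hex_to_rgba(HEX[n]) for n in names])
```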
Since the number of colors is small, you can compute an Nx5 array of distances to the colors:
distance = ((rgba[..., None, :] - lookup)**2).sum(axis=-1)
If you don't want to include the transparency in the distance, remove it from the comparison:
distance = ((rgba[..., None, :3] - lookup[..., :3])**2).sum(axis=-1)
This gives you an Nx5 array of distances (where N can span more than one dimension, because of the intentional use of ... instead of :). The minima are at:
closest = distance.argmin(-1)
Now you can apply this index directly to the lookup table:
result = lookup[closest]
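Because of the ..., the same lines work unchanged on a full (H, W, 4) image. The sketch below uses a hypothetical 2 x 3 image and the CSS3 RGB values of the five colours. One caveat, assuming the image is stored as uint8 as decoded images usually are: subtracting two uint8 arrays wraps around instead of going negative, so it is safer to cast to a signed type before computing distances.

```python
import numpy as np

# Hypothetical 2 x 3 RGBA image stored as uint8
img = np.array([
    [[0, 0, 0, 255], [250, 10, 5, 255], [10, 120, 20, 255]],
    [[255, 250, 245, 0], [160, 40, 45, 255], [250, 160, 10, 255]],
], dtype=np.uint8)

# One RGB row per colour: green, red, orange, brown, white (CSS3 values)
lookup = np.array([[0, 128, 0], [255, 0, 0], [255, 165, 0],
                   [165, 42, 42], [255, 255, 255]], dtype=np.uint8)

# uint8 arithmetic wraps (10 - 20 gives 246, not -10), so cast to int32 first
rgb = img[..., :3].astype(np.int32)

# (2, 3, 1, 3) - (5, 3) broadcasts to (2, 3, 5, 3); summing the last axis gives (2, 3, 5)
distance = ((rgb[..., None, :] - lookup.astype(np.int32)) ** 2).sum(axis=-1)
closest = distance.argmin(-1)  # shape (2, 3): index of the nearest colour per pixel
print(distance.shape, closest.tolist())  # (2, 3, 5) [[0, 1, 0], [4, 3, 2]]
```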
Here is a sample run:
>>> np.random.seed(42)
>>> rgba = np.random.randint(255, size=(10, 4))
>>> rgba
array([[102, 179, 92, 14],
[106, 71, 188, 20],
[102, 121, 210, 214],
[ 74, 202, 87, 116],
[ 99, 103, 151, 130],
[149, 52, 1, 87],
[235, 157, 37, 129],
[191, 187, 20, 160],
[203, 57, 21, 252],
[235, 88, 48, 218]])
>>> lookup = np.array([
... [0, 255, 0, 1],
... [255, 0, 0, 1],
... [92, 64, 51, 1],
... [255, 165, 0, 1],
... [255, 255, 255, 0]], dtype=np.uint8)
>>> distance = ((rgba[..., None, :3] - lookup[..., :3])**2).sum(axis=-1)
>>> distance
array([[ 24644, 63914, 15006, 32069, 55754],
[ 80436, 62586, 19014, 66381, 60546],
[ 72460, 82150, 28630, 69445, 43390],
[ 15854, 81134, 20664, 41699, 63794],
[ 55706, 57746, 11570, 50981, 58256],
[ 63411, 13941, 5893, 24006, 116961],
[ 66198, 26418, 29294, 1833, 57528],
[ 41505, 39465, 25891, 4980, 63945],
[ 80854, 6394, 13270, 14809, 96664],
[ 85418, 10448, 21034, 8633, 71138]])
>>> closest = distance.argmin(-1)
>>> closest
array([2, 2, 2, 0, 2, 2, 3, 3, 1, 3])
>>> lookup[closest]
array([[ 92, 64, 51, 1],
[ 92, 64, 51, 1],
[ 92, 64, 51, 1],
[ 0, 255, 0, 1],
[ 92, 64, 51, 1],
[ 92, 64, 51, 1],
[255, 165, 0, 1],
[255, 165, 0, 1],
[255, 0, 0, 1],
[255, 165, 0, 1]], dtype=uint8)
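Putting these steps together, one possible way to get what the question asks for is to keep two tables in the same row order: one with the colours to match against and one with the RGBA values to write. recolour, MATCH and REPLACE are hypothetical names for this sketch; MATCH uses the CSS3 RGB values of the five colours and REPLACE uses the replacement values from the question.

```python
import numpy as np

# Colours to match against (CSS3 RGB values): green, red, orange, brown, white
MATCH = np.array([
    [0, 128, 0],
    [255, 0, 0],
    [255, 165, 0],
    [165, 42, 42],
    [255, 255, 255],
])

# What each matched colour becomes, in the same row order (values from the question)
REPLACE = np.array([
    [0, 255, 0, 1],
    [255, 0, 0, 1],
    [255, 165, 0, 1],
    [92, 64, 51, 1],
    [255, 255, 255, 0],
])

def recolour(img, match=MATCH, replace=REPLACE):
    """Replace each pixel of an (..., 4) RGBA array with the row of `replace`
    whose row in `match` is nearest in RGB (squared Euclidean distance)."""
    rgb = img[..., :3].astype(np.int64)                      # avoid uint8 wraparound
    distance = ((rgb[..., None, :] - match) ** 2).sum(axis=-1)
    closest = distance.argmin(axis=-1)                       # nearest colour index per pixel
    return replace[closest]                                  # fancy indexing maps index -> RGBA

img = np.array([[[0, 0, 0, 0], [250, 10, 5, 255]],
                [[160, 40, 45, 255], [250, 160, 10, 255]]], dtype=np.uint8)
print(recolour(img).tolist())
# [[[0, 255, 0, 1], [255, 0, 0, 1]], [[92, 64, 51, 1], [255, 165, 0, 1]]]
```

The intermediate distance array holds five values per pixel, so its memory grows linearly with the image size; for very large images you could process the image a block of rows at a time.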