Is there a numpy magic avoiding these loops?

Time:03-10

I would like to avoid for loops in this code snippet:

import numpy as np

N = 4
a = np.random.randint(0, 256, size=(N, N, 3))
m = np.random.randint(0, 2, size=(N, N))

for i, d0 in enumerate(a):
  for j, d1 in enumerate(d0):
    if m[i, j]:
      d1[2] = 42  # d1 is a view of pixel a[i, j], so this writes into a

This is a simplified example: a is an N x N RGB image and m is an N x N mask of 0s and 1s. The goal is to set the masked pixels of the third channel only, a[:, :, 2], to 42, leaving the other two channels untouched.

Answer:

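You can index the last axis and assign to the elements selected by a boolean mask. Convert m to bool first: indexing with the 0/1 integer array itself would be treated as row indices 0 and 1, not as a mask.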

import numpy as np

N = 4
a = np.random.randint(0, 256, size=(N, N, 3))
m = np.random.randint(0, 2, size=(N, N))

# a[..., 2] is a view of the third channel, so assigning through it modifies a
a[..., 2][m.astype(bool)] = 42
a

Output (for one random draw of a and m):

array([[[ 9, 13,  4],
        [15,  0, 42],
        [11, 12,  9],
        [13,  0, 42]],

       [[ 1, 10, 42],
        [ 9,  0, 42],
        [ 8,  6,  4],
        [ 3,  0, 42]],

       [[15, 11,  6],
        [ 8, 11, 42],
        [14,  1, 42],
        [ 4, 14,  1]],

       [[ 3,  6, 42],
        [ 5, 13,  3],
        [ 9, 14, 13],
        [12,  6, 42]]])
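The chained form a[..., 2][mask] works because a[..., 2] is a view, so the boolean assignment lands in a. A minimal sketch, assuming the same shapes as the question (the seeded generator is only there for reproducibility), that folds the mask and the channel index into a single indexing operation and checks it against the original loop:

```python
import numpy as np

N = 4
rng = np.random.default_rng(0)  # seeded generator so the check is reproducible
a = rng.integers(0, 256, size=(N, N, 3))
m = rng.integers(0, 2, size=(N, N))

# Reference result: the original double loop, applied to a copy of a.
expected = a.copy()
for i, row in enumerate(expected):
    for j, pixel in enumerate(row):
        if m[i, j]:
            pixel[2] = 42  # pixel is a view, so this modifies expected

# Vectorized: the boolean (N, N) mask picks pixels, the 2 picks the channel.
result = a.copy()
result[m.astype(bool), 2] = 42

print(np.array_equal(result, expected))  # True
```

Combining both indices in one subscript avoids relying on the intermediate view, which makes the intent explicit.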

Answer:

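Alternatively, use arithmetic on the 0/1 mask: multiplying by (1 - m) zeroes the masked entries of the third channel, and adding m*42 then sets exactly those entries to 42.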

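a[:,:,2] *= (1-m)   # zero the masked entries (m must contain only 0s and 1s)
a[:,:,2] += m*42    # add 42 only where m == 1; plain "=" would overwrite unmasked values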
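The arithmetic trick relies on m holding only 0 and 1, and its second step must add (+=) rather than assign, or every unmasked value in the channel is replaced by 0. np.where and np.copyto express the same selection directly. A sketch, assuming the question's shapes and a seeded generator, that checks the two-step update against both:

```python
import numpy as np

N = 4
rng = np.random.default_rng(1)
a = rng.integers(0, 256, size=(N, N, 3))
m = rng.integers(0, 2, size=(N, N))  # only 0s and 1s, as the arithmetic trick requires

# Arithmetic answer: zero the masked entries, then add 42 to exactly those entries.
arith = a.copy()
arith[:, :, 2] *= (1 - m)
arith[:, :, 2] += m * 42

# np.where: take 42 where the mask is set, otherwise keep the original channel value.
selected = a.copy()
selected[:, :, 2] = np.where(m.astype(bool), 42, selected[:, :, 2])

# np.copyto writes 42 into the channel view only where the mask is True.
copied = a.copy()
np.copyto(copied[:, :, 2], 42, where=m.astype(bool))

print(np.array_equal(arith, selected), np.array_equal(arith, copied))  # True True
```

The np.where and np.copyto forms also work unchanged if the mask is already boolean, whereas the arithmetic form needs an integer 0/1 mask.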