I have a 2D array data and a boolean array mask, both of shape (M, N). I need to randomly pick one element from each row of data. However, the element I pick must be one where mask is True. Is there a way to do this without looping over every row? In every row, mask is True for at least 2 elements.
Minimal working example:
import numpy
data = numpy.arange(8).reshape((2,4))
mask = numpy.array([[True, True, True, True], [True, True, False, False]])
selected_data = numpy.random.choice(data, mask, num_elements=1, axis=1)  # not a real call
The last line above doesn't work (numpy.random.choice only accepts a, size, replace and p, so it has no mask, num_elements or axis argument), but I want something like it. Some valid results are listed below.
selected_data = [0,4]
selected_data = [1,5]
selected_data = [2,5]
selected_data = [3,4]
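To make "valid" precise: a result is valid when it has one value per row and each value is one of the entries of its row where mask is True. A small checker, sketched with a hypothetical helper name (is_valid_selection is not part of NumPy):

```python
import numpy

data = numpy.arange(8).reshape((2, 4))
mask = numpy.array([[True, True, True, True], [True, True, False, False]])

def is_valid_selection(selected, data, mask):
    """Hypothetical helper: True if selected[i] is a value of data[i] where mask[i] is True."""
    selected = numpy.asarray(selected)
    # Exactly one pick per row is required.
    if selected.shape != (data.shape[0],):
        return False
    # Compare each pick against the allowed (masked-in) values of its row.
    return all(selected[i] in data[i][mask[i]] for i in range(data.shape[0]))

print(is_valid_selection([3, 4], data, mask))
print(is_valid_selection([3, 6], data, mask))  # 6 sits where mask is False
```

The loop here is only for checking results, not for producing them.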
Answer:
It is easier to work with the indices of the mask. Get the (row, column) indices of the True values from the mask and stack them into a 2D array of coordinates, indices2d; every coordinate in it is a valid candidate. Shuffle that array along its first axis, then for each row keep the first coordinate pair with that row index (numpy.unique with return_index=True gives these first positions). Because the array was shuffled, that pick is uniformly random among the row's True entries. Finally, use the selected coordinates to index the original data. See below:
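import numpy
data = numpy.arange(8).reshape((2,4))
mask = numpy.array([[True, True, True, True], [True, True, False, False]])
for _ in range(20):
    # (K, 2) array of (row, col) coordinates where mask is True
    indices2d = numpy.dstack(numpy.where(mask)).squeeze().astype(numpy.int32)
    # shuffle whole (row, col) pairs in place, along the first axis
    numpy.random.shuffle(indices2d)
    # keep the first occurrence of each row index in the shuffled array
    randomElements = indices2d[numpy.unique(indices2d[:, 0], return_index=True)[1]]
    print(data[randomElements[:,0],randomElements[:,1]])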
Output
[0 5]
[1 4]
[0 5]
[1 4]
[1 5]
[0 5]
[1 4]
[1 5]
[3 4]
[2 5]
[2 4]
[3 5]
[2 4]
[0 4]
[0 4]
[0 4]
[0 5]
[3 5]
[3 5]
[1 4]
12.7 ms ± 80.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
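The same shuffle-then-first-occurrence idea can be wrapped in a reusable function. This is a sketch, not the answerer's code: it assumes NumPy's Generator API (numpy.random.default_rng) and uses numpy.argwhere, which returns the same (K, 2) coordinate array as numpy.dstack(numpy.where(mask)).squeeze() for a 2D mask. The function name pick_masked_per_row is hypothetical, and it assumes every row of mask has at least one True entry (a row with none would silently be dropped from the result).

```python
import numpy

def pick_masked_per_row(data, mask, rng=None):
    """Hypothetical helper: one value per row of data, uniform over that row's True entries."""
    rng = numpy.random.default_rng() if rng is None else rng
    # (K, 2) array of (row, col) coordinates of the True entries.
    coords = numpy.argwhere(mask)
    # Generator.shuffle permutes along axis 0, i.e. whole (row, col) pairs.
    rng.shuffle(coords)
    # Position of the first occurrence of each row index, ordered by row.
    _, first = numpy.unique(coords[:, 0], return_index=True)
    chosen = coords[first]
    return data[chosen[:, 0], chosen[:, 1]]

data = numpy.arange(8).reshape((2, 4))
mask = numpy.array([[True, True, True, True], [True, True, False, False]])
print(pick_masked_per_row(data, mask, numpy.random.default_rng(0)))
```

Passing an explicit rng makes runs reproducible, which helps when comparing against the timings above.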
Answer:
The only solution I can think of still loops over the rows, using a list comprehension:
import numpy as np
[np.random.choice(data[i], p=mask[i]/mask[i].sum()) for i in range(data.shape[0])]
#60 µs ± 2.37 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
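The p argument above works because dividing a boolean row by its number of True entries gives a probability vector that is uniform over the allowed columns and zero elsewhere. A small sketch of that normalization for the whole mask at once; the keepdims broadcasting is an addition for illustration, not part of the answer:

```python
import numpy as np

data = np.arange(8).reshape((2, 4))
mask = np.array([[True, True, True, True], [True, True, False, False]])

# True/False act as 1/0, so each row is divided by its count of True entries.
p = mask / mask.sum(axis=1, keepdims=True)
# p[0] is [0.25, 0.25, 0.25, 0.25], p[1] is [0.5, 0.5, 0.0, 0.0]

# Same per-row loop as the answer, reusing the precomputed rows of p.
picked = [np.random.choice(data[i], p=p[i]) for i in range(data.shape[0])]
print(picked)
```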
You can speed it up by selecting the allowed values with the boolean mask directly, instead of passing probabilities:
[np.random.choice(data[i][mask[i]]) for i in range(data.shape[0])]
#29.9 µs ± 4.77 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
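Both answers still iterate over rows in Python. A fully vectorized alternative, which is a different technique from either answer (often called the "argmax of random keys" trick), is to draw a uniform random number for every cell, push the masked-out cells below any possible draw, and take the argmax of each row. Because the draws are independent and identically distributed, each allowed cell in a row is equally likely to hold the row's maximum (ties have negligible probability). A minimal sketch, assuming NumPy's Generator API and at least one True per row; the function name is hypothetical:

```python
import numpy as np

def pick_masked_vectorized(data, mask, rng=None):
    """Hypothetical helper: one value per row, uniform over that row's True entries.

    Assumes every row of mask has at least one True entry; an all-False row
    would silently return column 0.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Uniform random keys in [0, 1) for every cell.
    keys = rng.random(mask.shape)
    # -1 is below every possible draw, so argmax can only land on a True cell.
    keys[~mask] = -1.0
    cols = keys.argmax(axis=1)
    return data[np.arange(data.shape[0]), cols]

data = np.arange(8).reshape((2, 4))
mask = np.array([[True, True, True, True], [True, True, False, False]])
print(pick_masked_vectorized(data, mask))
```

The cost is one random draw per cell rather than per True entry, but there is no Python-level loop, so it should scale well when M is large; this has not been benchmarked against the timings quoted above.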