Say I have an array:
x = np.array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
And a multi-labeled mask:
labels = np.array([[0, 0, 2],
[1, 1, 2],
[1, 1, 2]])
My goal is to sum the entries of x together, grouped by labels. For example:
n_labels = np.max(labels) + 1
out = np.empty(n_labels)
for label in range(n_labels):
    mask = labels == label
    out[label] = np.sum(x[mask])
>>> out
array([ 1., 20., 15.])
However, as x and n_labels become large, I see this being inefficient. Each iteration, you are only summing together a small fraction of the entries of x, but still recompute the mask over all of labels (in the expression labels == label) and subsequently index over all of x (in the expression x[mask]). Is there a more efficient way to do this as x and n_labels grow large?
CodePudding user response:
You can use np.bincount with weights:
np.bincount(labels.ravel(), weights=x.ravel())
out:
array([ 1., 20., 15.])
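To make the bincount answer a bit more concrete, here is a minimal sketch using the arrays from the question. The names flat_labels, int_sums, padded and the value n_labels = 5 are illustrative assumptions, not part of the original answer. It shows that the weighted result comes back as floats and that minlength can pad the output when some trailing labels never occur.

```python
import numpy as np

x = np.array([[0, 1, 2],
              [3, 4, 5],
              [6, 7, 8]])
labels = np.array([[0, 0, 2],
                   [1, 1, 2],
                   [1, 1, 2]])

# np.bincount only accepts a 1D array of non-negative integers,
# so both arrays are flattened first.
flat_labels = labels.ravel()
sums = np.bincount(flat_labels, weights=x.ravel())

# With weights, the result is float64 even when x holds integers;
# cast back if integer sums are wanted.
int_sums = sums.astype(x.dtype)

# minlength guarantees at least that many output bins.
# n_labels = 5 is a made-up value, larger than labels.max() + 1.
n_labels = 5
padded = np.bincount(flat_labels, weights=x.ravel(), minlength=n_labels)

print(sums)
print(int_sums)
print(padded)
```

This makes a single pass over the data instead of one pass per label, which is the inefficiency the question asks about.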
CodePudding user response:
You don't really have a reason to operate on 2D arrays, so ravel them first:
labels = labels.ravel()
x = x.ravel()
If your labels are already indices, you can use np.argsort along with np.diff and np.add.reduceat:
index = labels.argsort()
splits = np.r_[0, np.flatnonzero(np.diff(labels[index])) + 1]
result = np.add.reduceat(x[index], splits)
labels[index] is the sorted array of labels. Whenever that changes, you enter a new group, and the diff is nonzero. That's what np.flatnonzero(np.diff(labels[index])) finds for you. Since reduceat takes the index one past the end of each run (which is where the next run starts), you need to add one. np.r_ allows you to prepend zero easily to a 1D array, which is necessary for reduceat to regard the first run (the last is always automatic). Before you run reduceat, you need to order x into the runs defined by labels, which is what x[index] does.
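The steps described above can be traced on the raveled arrays from the question. This is only a sketch: the intermediate names sorted_labels, changes and group_ids are introduced here for illustration, and kind="stable" is an optional choice that keeps the within-group order deterministic.

```python
import numpy as np

# Raveled versions of the arrays from the question.
x = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])
labels = np.array([0, 0, 2, 1, 1, 2, 1, 1, 2])

# Order that groups equal labels into contiguous runs.
index = labels.argsort(kind="stable")
sorted_labels = labels[index]          # [0 0 1 1 1 1 2 2 2]

# Positions where the sorted label differs from its successor.
changes = np.flatnonzero(np.diff(sorted_labels))   # [1 5]

# Start index of every run: 0 for the first, change + 1 for the rest.
splits = np.r_[0, changes + 1]         # [0 2 6]

# Sum x, reordered into runs, between consecutive start indices.
result = np.add.reduceat(x[index], splits)         # [ 1 20 15]

# Label that each entry of result belongs to.
group_ids = sorted_labels[splits]      # [0 1 2]

print(result, group_ids)
```

One thing to keep in mind: if some label values never occur, result has one entry per label that is present, so result[i] belongs to group_ids[i] and not necessarily to label i.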
CodePudding user response:
You can use 2D arrays with another slow and over-engineered approach using np.add.at:
sums = np.zeros(labels.max() + 1, x.dtype)
np.add.at(sums, labels, x)
sums
Output
array([ 1, 20, 15])
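For anyone wondering why np.add.at is used here instead of a plain in-place add, the sketch below compares the two on the arrays from the question. The name buffered is made up for this illustration. With fancy indexing, sums[labels] += x does not accumulate over repeated indices, while np.add.at performs the unbuffered operation that does.

```python
import numpy as np

x = np.array([[0, 1, 2],
              [3, 4, 5],
              [6, 7, 8]])
labels = np.array([[0, 0, 2],
                   [1, 1, 2],
                   [1, 1, 2]])

# Unbuffered in-place add: every occurrence of a label contributes.
sums = np.zeros(labels.max() + 1, dtype=x.dtype)
np.add.at(sums, labels, x)

# Buffered version for comparison: repeated indices are NOT accumulated,
# so this does not produce the per-label sums.
buffered = np.zeros(labels.max() + 1, dtype=x.dtype)
buffered[labels] += x

print(sums)
print(buffered)
```

Unlike bincount with weights, this keeps the integer dtype of x and works on the 2D arrays directly.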