I have a for loop that I need to vectorize. The code below works, but it is slow (this is a simplified example; the full version will have about 1e6 elements in col_ids). How can I vectorize this code to get rid of the loop? In case it matters, col_ids is fixed (it is the same every time the code runs), while values changes.
import numpy as np

values = np.array([1.5, 2, 2.3])
col_ids = np.array([[0,0,0,0], [0,0,0,1], [0,0,1,1]])
result = np.zeros((4,3))
# For each value, add it to one column (given by col_idx) in every row
for idx, col_idx in enumerate(col_ids):
    result[np.arange(4), col_idx] += values[idx]
Result:
[[5.8 0. 0. ]
[5.8 0. 0. ]
[3.5 2.3 0. ]
[1.5 4.3 0. ]]
Answer:
You can solve this using np.add.at. One way is to flatten the arrays, compute the 1D flat indices, and then call the function:
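result = np.zeros((4,3))
n, m = result.shape
# Flat index of (row i, column col_ids[j, i]) is i*m + col_ids[j, i]
indices = np.tile(np.arange(0, n*m, m), col_ids.shape[0]) + col_ids.ravel()
# result.ravel() is a view of the contiguous result, so this updates result in place
np.add.at(result.ravel(), indices, np.repeat(values, n))
print(result)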
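For reference, np.add.at also accepts a tuple of index arrays when its first operand is multi-dimensional, so the flattening step can be skipped. The sketch below (the names rows, cols and vals are illustrative) also shows why a plain fancy-indexing += is not enough once all pairs go into a single call: that form is buffered, so a (row, column) pair that occurs several times is not summed, whereas np.add.at is unbuffered and accumulates every occurrence.

```python
import numpy as np

values = np.array([1.5, 2, 2.3])
col_ids = np.array([[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 1]])
n = col_ids.shape[1]   # number of rows in result
m = len(values)        # number of columns, as in the question

rows = np.tile(np.arange(n), len(values))  # row index for every (value, row) pair
cols = col_ids.ravel()                     # matching column index
vals = np.repeat(values, n)                # value to add at each pair

# Buffered fancy indexing: repeated (row, col) pairs are NOT summed
buffered = np.zeros((n, m))
buffered[rows, cols] += vals

# Unbuffered np.add.at with a tuple of index arrays on the 2D target
result = np.zeros((n, m))
np.add.at(result, (rows, cols), vals)
print(result)
```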
Answer:
You can vectorize the loop with broadcasting, but this creates an additional 3D intermediate array (one slice the size of result per value), which makes it much slower than the loop for larger data (starting from a result of shape (50,50)):
import numpy as np
values = np.array([1.5, 2, 2.3])
col_ids = np.array([[0,0,0,0], [0,0,0,1], [0,0,1,1]])
(np.equal.outer(col_ids, np.arange(len(values))) * values[:,None,None]).sum(0)
Output
array([[5.8, 0. , 0. ],
[5.8, 0. , 0. ],
[3.5, 2.3, 0. ],
[1.5, 4.3, 0. ]])
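To see where the extra memory in the broadcasting version comes from, the intermediate arrays can be inspected directly. This sketch only re-runs the expression above step by step; the names mask and weighted are illustrative.

```python
import numpy as np

values = np.array([1.5, 2, 2.3])
col_ids = np.array([[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 1]])

# Boolean one-hot mask: mask[j, i, c] is True where col_ids[j, i] == c
mask = np.equal.outer(col_ids, np.arange(len(values)))
print(mask.shape)       # (3, 4, 3): values x rows x columns

# Multiplying by the values promotes the mask to float64 (8 bytes per entry)
weighted = mask * values[:, None, None]
print(weighted.nbytes)  # 288 = 3 * 4 * 3 * 8

# Summing over the value axis gives the same result as the loop
print(weighted.sum(0))
```

The weighted array holds len(values) * n * m floats before .sum(0) collapses it, so its size grows with the number of values times the size of result.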
The only reliably faster solution I could find uses numba (version 0.55.1). I expected this implementation to benefit from parallel execution, but I couldn't get any speed-up on a 2-core Colab instance.
import numpy as np
import numba as nb

@nb.njit(parallel=False)  # Try parallel=True for multi-threaded execution; no speed-up in my benchmarks
def fill(val, ids):
    res = np.zeros(ids.shape[::-1])  # one row per column of ids, one column per value
    for i in nb.prange(len(res)):  # each i writes only row i, so prange is safe
        for j in range(res.shape[1]):
            res[i, ids[j, i]] += val[j]
    return res

fill(values, col_ids)
Output
array([[5.8, 0. , 0. ],
[5.8, 0. , 0. ],
[3.5, 2.3, 0. ],
[1.5, 4.3, 0. ]])
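Since the question says col_ids is fixed, another NumPy-only option (not one of the answers above) is to compute the flat indices once and reuse them with np.bincount, whose weights argument sums every value that lands on the same flat index. make_filler and fill_values are hypothetical names; the sketch assumes all col_ids entries lie in range(n_cols). np.bincount is often faster than np.add.at, but that should be benchmarked on the real data.

```python
import numpy as np

def make_filler(col_ids, n_cols):
    """Hypothetical helper: precompute flat indices for a fixed col_ids.

    Returns a function mapping a new `values` vector to the (n_rows, n_cols) result.
    """
    _, n_rows = col_ids.shape
    size = n_rows * n_cols
    # Flat index of (row i, column col_ids[j, i]) in a C-ordered (n_rows, n_cols) array;
    # raveling in C order lists all rows for value 0, then value 1, and so on.
    flat = (np.arange(n_rows) * n_cols + col_ids).ravel()

    def fill_values(values):
        # values[j] contributes once to each of the n_rows entries in col_ids[j]
        weights = np.repeat(values, n_rows)
        # bincount sums the weights that share a flat index (the accumulation step)
        return np.bincount(flat, weights=weights, minlength=size).reshape(n_rows, n_cols)

    return fill_values

col_ids = np.array([[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 1]])
fill_values = make_filler(col_ids, n_cols=3)   # done once, since col_ids is fixed
result = fill_values(np.array([1.5, 2, 2.3]))  # repeated for every new values
print(result)
```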