I want to generate a 2D numpy array with elements calculated from their positions. Something like the following code:
import numpy as np
def calculate_element(i, j, other_parameters):
    """Return the value of the array element at position ``(i, j)``.

    Placeholder from the question: the actual computation is elided, and
    ``value_at_i_j`` stands for whatever that computation produces.
    """
    # do something
    return value_at_i_j
def main():
    """Fill an (M, N) array one element at a time with ``calculate_element``.

    ``M``, ``N`` and the trailing ``...`` argument are placeholders in the
    question; they stand for the real shape and the extra parameters of
    ``calculate_element``.
    """
    arr = np.zeros((M, N))  # (M, N) is the shape of the array
    for i in range(M):
        for j in range(N):
            # Use one multi-dimensional index: ``arr[i][j]`` would first
            # create a temporary row view and then index into that view.
            arr[i, j] = calculate_element(i, j, ...)
This code runs extremely slowly, since loops in Python are just not very efficient. Is there any way to do this faster in this case?
By the way, for now I use a workaround by calculating two 2D "index matrices". Something like this:
def main():
    """Compute the whole array with a single vectorised ``calculate_element`` call.

    Two full (M, N) integer index matrices are built and passed in place of
    the scalar ``i`` and ``j``. ``M``, ``N`` and the trailing ``...``
    argument are placeholders in the question.
    """
    # index_matrix_i[r, c] == r and index_matrix_j[r, c] == c:
    #
    #   index_matrix_i is like    index_matrix_j is like
    #   [[0, 0, 0, ...],          [[0, 1, 2, ...],
    #    [1, 1, 1, ...],           [0, 1, 2, ...],
    #    [2, 2, 2, ...],           [0, 1, 2, ...],
    #    ...                       ...
    #   ]                         ]
    #
    # np.indices produces both grids directly, instead of converting a Python
    # list of range objects element by element and transposing the result.
    index_matrix_i, index_matrix_j = np.indices((M, N))
    arr = calculate_element(index_matrix_i, index_matrix_j, ...)
Edit1: The code becomes much faster after I apply the "index matrices" trick, so my main question is whether there is a way to avoid this trick, since it takes more memory. In short, I want a solution that is efficient in both time and space.
Edit2: Some examples I tested
# a simple 2D Gaussian
def calculate_element(i, j, i_mid, j_mid, i_sig, j_sig):
    """Evaluate a separable 2D Gaussian at position ``(i, j)``.

    The result is the product of two 1D Gaussian profiles: one along ``i``
    centred on ``i_mid`` with width ``i_sig``, and one along ``j`` centred
    on ``j_mid`` with width ``j_sig``. ``i`` and ``j`` may be scalars or
    NumPy arrays; the arithmetic is applied element-wise.
    """

    def _axis_profile(position, centre, sigma):
        # Unnormalised 1D Gaussian: exp(-(x - mu)^2 / (2 * sigma^2)).
        return np.exp(-((position - centre)**2) / (2 * sigma**2))

    return _axis_profile(i, i_mid, i_sig) * _axis_profile(j, j_mid, j_sig)
# shape of the array: M rows, N columns
M, N = 1200, 4000
# use for loops to go through every element
# this code takes ~10 seconds
def main_1():
    """Build the (M, N) Gaussian element by element and display it.

    Every element is computed by a separate scalar call to
    ``calculate_element``; the finished array is shown as an image.
    """
    arr = np.zeros((M, N))  # (M, N) is the shape of the array
    for i in range(M):
        for j in range(N):
            # Use one multi-dimensional index: ``arr[i][j]`` would first
            # create a temporary row view and then index into that view.
            arr[i, j] = calculate_element(i, j, 600, 2000, 300, 500)
    # print(arr)
    # NOTE(review): `plt` is not imported in the visible code -- presumably
    # matplotlib.pyplot; confirm.
    plt.figure(figsize=(8, 5))
    plt.imshow(arr, aspect='auto', origin='lower')
    plt.show()
# use index matrices
# this code takes <1 second
def main_2():
    """Build the (M, N) Gaussian in one vectorised call and display it.

    Two full (M, N) index matrices replace the scalar ``i`` and ``j``, so
    ``calculate_element`` evaluates every element at once.
    """
    # index_matrix_i[r, c] == r and index_matrix_j[r, c] == c. np.indices
    # produces both grids directly, instead of converting a Python list of
    # range objects element by element and transposing the result.
    index_matrix_i, index_matrix_j = np.indices((M, N))
    arr = calculate_element(index_matrix_i, index_matrix_j, 600, 2000, 300, 500)
    # print(arr)
    # NOTE(review): `plt` is not imported in the visible code -- presumably
    # matplotlib.pyplot; confirm.
    plt.figure(figsize=(8, 5))
    plt.imshow(arr, aspect='auto', origin='lower')
    plt.show()
CodePudding user response:
Use a JIT-compiled, parallel Numba function:
import numba as nb # tested with numba 0.55.1
@nb.njit(parallel=True)
def calculate_element_nb(i, j, i_mid, j_mid, i_sig, j_sig):
    """Return an ``(i, j)`` float32 array holding a separable 2D Gaussian.

    Note that ``i`` and ``j`` are the *shape* of the result (number of rows
    and columns), not the coordinates of a single element. Element
    ``[row, col]`` is
    ``exp(-(row - i_mid)**2 / (2 * i_sig**2)) * exp(-(col - j_mid)**2 / (2 * j_sig**2))``.
    The outer (row) loop is a ``nb.prange`` loop.
    """
    res = np.empty((i, j), np.float32)

    # The column factor depends only on the column index, so evaluate it once
    # per column up front instead of once per element.
    gaus_cols = np.empty(res.shape[1], np.float64)
    for col in range(res.shape[1]):
        gaus_cols[col] = np.exp(-(col - j_mid)**2 / (2 * j_sig**2))

    for row in nb.prange(res.shape[0]):
        # The row factor is constant along the row: hoisted out of the inner loop.
        gaus_row = np.exp(-(row - i_mid)**2 / (2 * i_sig**2))
        for col in range(res.shape[1]):
            res[row, col] = gaus_row * gaus_cols[col]
    return res
# Shape of the result: M rows, N columns.
M, N = 1200, 4000
# Standalone call whose result is discarded -- presumably the statement that
# was timed with %timeit (see the measurement below).
calculate_element_nb(M, N, 600, 2000, 300, 500)
# %timeit 10 loops, best of 5: 80.4 ms per loop
# NOTE(review): `plt` is not imported in the visible code -- presumably
# matplotlib.pyplot; confirm.
plt.figure(figsize=(8, 5))
# The array is computed again here and displayed as an image.
plt.imshow(calculate_element_nb(M, N, 600, 2000, 300, 500), aspect='auto', origin='lower')
plt.show()