Extending and optimizing 2D grid search code to N-dimensions (using itertools)

Time:12-16

I have code for a 2D grid search and it works perfectly. Here is the sample code:

chinp = np.zeros((N, N))
O_M_De = []

for x, y in list(((x, y) for x in range(len(omega_M)) for y in range(len(omega_D)))):
    Omin = omega_M[x]
    Odin = omega_D[y]
    print(Omin, Odin)

    chi = np.sum((dist_data - dist_theo)**2 / chi_err)

    chinp[y, x] = chi
    chi_values1.append(chi)
    O_M_De.append((x, y))

At some point in the future, I may want to perform a grid search over more dimensions. For 3 dimensions, it would be as simple as adding another variable 'z' to my 'for' statement (line 3). I could keep adding dimensions this way (I have tried it and it works).

However, as you can tell, if I wanted to grid search over a large number of dimensions, it would get tedious and inefficient to keep adding variables to my 'for' statement (e.g. for 5D it would go something like 'for v,w,x,y,z in list(((v,w,x,y,z)...').

From various Google searches, I am under the impression that itertools is very helpful for grid searches. However, I am still fairly new to programming and unfamiliar with it.

My question: does anyone know a way (using itertools or some other method I am not aware of) to extend this code to N dimensions more efficiently, i.e. to change the 'for' statement so that I can grid search over N dimensions without adding another 'for z in range...' each time?

Thank you in advance for your help.

User response:

Take a look at the product function from itertools. It yields the Cartesian product of the iterables you pass in, which is equivalent to nested for loops:

import itertools

x_list = [0, 1, 2]
y_list = [10, 11, 12]
z_list = [20, 21, 22]

for x, y, z in itertools.product(x_list, y_list, z_list):
    print(x, y, z)

Output:

0 10 20
0 10 21
0 10 22
0 11 20
0 11 21
(...)
2 11 21
2 11 22
2 12 20
2 12 21
2 12 22
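
To connect this back to the question, here is a minimal sketch of an N-dimensional grid search built on itertools.product. The helper name grid_search, the third axis h, and the toy_chi function are hypothetical stand-ins; toy_chi is not the real chi-square from the question, only a quadratic with a known minimum so the result can be checked.

```python
import itertools
import numpy as np

def grid_search(axes, chi_fn):
    """Evaluate chi_fn at every point of the N-dimensional grid spanned by axes.

    axes   : list of 1-D sequences, one per parameter (hypothetical name)
    chi_fn : callable taking one value per axis, returning a float (hypothetical)
    """
    shape = tuple(len(a) for a in axes)
    chi_grid = np.empty(shape)
    # One index range per axis; product(*...) replaces the nested for loops,
    # so the same code handles 2, 3, or N dimensions.
    for idx in itertools.product(*(range(n) for n in shape)):
        params = tuple(axis[i] for axis, i in zip(axes, idx))
        chi_grid[idx] = chi_fn(*params)
    return chi_grid

# Toy stand-in for the real chi-square: minimum at (0.3, 0.7, 1.0).
def toy_chi(om, od, hh):
    return (om - 0.3)**2 + (od - 0.7)**2 + (hh - 1.0)**2

omega_M = np.linspace(0.0, 1.0, 11)
omega_D = np.linspace(0.0, 1.0, 11)
h = np.linspace(0.5, 1.5, 11)  # assumed third parameter, for illustration only

chi_grid = grid_search([omega_M, omega_D, h], toy_chi)

# Convert the flat argmin position back into one index per axis.
best_idx = np.unravel_index(np.argmin(chi_grid), chi_grid.shape)
print(best_idx)
```

Adding a dimension only means appending another array to the list passed as axes. Note that this sketch stores results as chi_grid[x, y, ...], in the same order as the axes, whereas the 2D code in the question fills chinp[y, x].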

Note that this is not the most efficient approach. You will get the best results by adding vectorization (for example with numpy or numba) and parallelism (with multiprocessing or numba).
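
As a rough illustration of the vectorization idea, the sketch below evaluates a whole grid at once with numpy broadcasting instead of a Python loop. It assumes the chi-square can be written with array operations; the quadratic used here is a toy stand-in, not the model from the question, and the third axis h is a made-up parameter.

```python
import numpy as np

omega_M = np.linspace(0.0, 1.0, 11)
omega_D = np.linspace(0.0, 1.0, 11)
h = np.linspace(0.5, 1.5, 11)  # assumed third parameter, for illustration only

axes = [omega_M, omega_D, h]

# indexing="ij" keeps axis k of the result aligned with the k-th input array;
# sparse=True returns small broadcastable arrays instead of full N-D copies.
OM, OD, H = np.meshgrid(*axes, indexing="ij", sparse=True)

# Toy chi-square evaluated on the whole grid in one expression.
chi_grid = (OM - 0.3)**2 + (OD - 0.7)**2 + (H - 1.0)**2

# Convert the flat argmin position back into one index per axis.
best_idx = np.unravel_index(np.argmin(chi_grid), chi_grid.shape)
print(chi_grid.shape, best_idx)
```

The full grid still has to fit in memory, so for many dimensions or fine grids this approach may need to be applied in chunks.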
