I want to build a function GridGen that does the following:
GridGen([a], -1, 1) --> [{a: -1}, {a: 0}, {a: 1}]
GridGen([a, b], -1, 1) --> [{a: -1, b: -1},{a: 0, b: -1}, {a: 1, b: -1},
{a: -1, b: 0}, {a: 0, b: 0}, {a: 1, b: 0},
{a: -1, b: 1}, {a: 0, b: 1}, {a: 1, b: 1}]
GridGen([a, b, c], -1, 1) --> [{a: -1, b: -1, c: -1}, {a: 0, b: -1, c: -1}, {a: 1, b: -1, c: -1}, ... ]
At the moment I achieve this with two functions and a simple recursion:
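from sage.all import *

def TensorMergeDict(dicts):
    # Base case: a single list of dicts needs no merging.
    if len(dicts) == 1:
        return dicts[0]
    if len(dicts) == 2:
        # Merge every pair; the first list's entries vary fastest.
        return flatten([[dicts[0][i] | dicts[1][j] for i in range(len(dicts[0]))] for j in range(len(dicts[1]))])
    else:
        # Merge the first list with the already-merged remaining lists.
        return TensorMergeDict([dicts[0], TensorMergeDict(dicts[1:])])

def GridGen(vars, minV, maxV, step = 1):
    # One list of single-entry dicts per variable, covering minV..maxV inclusive.
    dicts = [[{e: i} for i in range(minV, maxV + 1, step)] for e in vars]
    return TensorMergeDict(dicts)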
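For comparison, the same pairwise merge can be written without Sage's flatten as a fold with functools.reduce. This is only a minimal sketch, not the code from the question: the names tensor_merge and grid_gen_reduce are hypothetical, the variables are passed as strings, and the dict union operator | assumes Python 3.9 or later (as the original code already does).

```python
from functools import reduce


def tensor_merge(dicts):
    """Merge a list of lists of single-entry dicts into all combinations.

    Hypothetical helper: the first list's entries vary fastest, as in
    TensorMergeDict. Starting from [{}] means an empty input gives [{}].
    """
    # Fold from the right so that dicts[0] ends up as the innermost loop.
    return reduce(
        lambda acc, left: [l | r for r in acc for l in left],
        reversed(dicts),
        [{}],
    )


def grid_gen_reduce(variables, min_v, max_v, step=1):
    # One list of single-entry dicts per variable, min_v..max_v inclusive.
    dicts = [[{v: i} for i in range(min_v, max_v + 1, step)] for v in variables]
    return tensor_merge(dicts)


print(grid_gen_reduce(["a", "b"], -1, 1))
```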
where Sage provides the convenient flatten function to flatten a list.
I wonder whether there is a better or more efficient way to do this. It feels like there should be some existing function in Python or SageMath that facilitates such an operation.
Answer:
How about itertools.product?
from itertools import product

def grid(keys, lo, hi):
    # Every combination of values in lo..hi (inclusive), one value per key.
    for p in product(range(lo, hi + 1), repeat=len(keys)):
        yield {k: v for k, v in zip(keys, p)}

for x in grid("abc", -1, 1):
    print(x)
{'a': -1, 'b': -1, 'c': -1}
{'a': -1, 'b': -1, 'c': 0}
{'a': -1, 'b': -1, 'c': 1}
{'a': -1, 'b': 0, 'c': -1}
{'a': -1, 'b': 0, 'c': 0}
{'a': -1, 'b': 0, 'c': 1}
{'a': -1, 'b': 1, 'c': -1}
{'a': -1, 'b': 1, 'c': 0}
{'a': -1, 'b': 1, 'c': 1}
{'a': 0, 'b': -1, 'c': -1}
{'a': 0, 'b': -1, 'c': 0}
{'a': 0, 'b': -1, 'c': 1}
{'a': 0, 'b': 0, 'c': -1}
{'a': 0, 'b': 0, 'c': 0}
{'a': 0, 'b': 0, 'c': 1}
{'a': 0, 'b': 1, 'c': -1}
{'a': 0, 'b': 1, 'c': 0}
{'a': 0, 'b': 1, 'c': 1}
{'a': 1, 'b': -1, 'c': -1}
{'a': 1, 'b': -1, 'c': 0}
{'a': 1, 'b': -1, 'c': 1}
{'a': 1, 'b': 0, 'c': -1}
{'a': 1, 'b': 0, 'c': 0}
{'a': 1, 'b': 0, 'c': 1}
{'a': 1, 'b': 1, 'c': -1}
{'a': 1, 'b': 1, 'c': 0}
{'a': 1, 'b': 1, 'c': 1}
keys can be a string or a list:
for x in grid(["foo", "bar"], -1, 1):
    print(x)
{'foo': -1, 'bar': -1}
{'foo': -1, 'bar': 0}
{'foo': -1, 'bar': 1}
{'foo': 0, 'bar': -1}
{'foo': 0, 'bar': 0}
{'foo': 0, 'bar': 1}
{'foo': 1, 'bar': -1}
{'foo': 1, 'bar': 0}
{'foo': 1, 'bar': 1}
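Note that product varies the last key fastest, whereas GridGen in the question varies the first key fastest, and grid has no step argument. A possible sketch that restores both, assuming a list rather than a generator is wanted (grid_gen_product is a hypothetical name, not an existing function):

```python
from itertools import product


def grid_gen_product(keys, lo, hi, step=1):
    """Return the grid as a list, with the first key varying fastest."""
    values = range(lo, hi + 1, step)
    # product() varies the last tuple position fastest; reversing each
    # tuple before zipping assigns that fastest position to keys[0].
    return [
        dict(zip(keys, reversed(p)))
        for p in product(values, repeat=len(keys))
    ]


for x in grid_gen_product(["a", "b"], -1, 1):
    print(x)
```

Reversing each tuple, rather than the keys, keeps the keys in their given order inside each dict.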