I need to create an array of all the permutations of the digits 0-9 of size N (an input, 1 <= N <= 10).
I've tried this:
np.array(list(itertools.permutations(range(10), n)))
For n=6:
timeit np.array(list(itertools.permutations(range(10), 6)))
on my machine gives:
68.5 ms ± 881 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
But it is simply not fast enough; I need it to be below 40 ms.
Note: I cannot change the NumPy version on the machine, which is fixed at 1.22.3.
Answer:
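Adapting the method from the link provided by @KellyBundy gives a much faster approach, which fills the array block by block with NumPy slicing instead of iterating over Python tuples: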
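import math
import numpy as np

def permutations_(n, k):
    # One row per k-permutation of range(n); uint8 is enough for small digits.
    a = np.zeros((math.perm(n, k), k), np.uint8)
    f = 1  # number of rows filled so far
    # Fill columns right to left; the suffix starting at column n - m draws from m values.
    for m in range(n - k + 1, n + 1):
        # Already-filled trailing columns of the first f rows.
        b = a[:f, n - m + 1:]
        for i in range(1, m):
            # Block i: leading value i, tail re-mapped so it skips i.
            a[i * f:(i + 1) * f, n - m] = i
            a[i * f:(i + 1) * f, n - m + 1:] = b + (b >= i)
        # Block 0 keeps leading value 0, so its tail (a view into a) must skip 0.
        b += 1
        f *= m
    return a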
Simple test:
In [125]: %timeit permutations_(10, 6)
3.96 ms ± 42.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [128]: np.array_equal(permutations_(10, 6), np.array(list(permutations(range(10), 6))))
Out[128]: True
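The single comparison above covers N=6 only. A sketch that exercises the function over the question's whole input range, comparing exactly against itertools for the smaller N (where building the reference list is cheap) and only checking shapes for the larger ones; the loop bounds are an arbitrary choice for illustration:

```python
import math
from itertools import permutations

import numpy as np


def permutations_(n, k):
    # Same block-filling construction as the function above.
    a = np.zeros((math.perm(n, k), k), np.uint8)
    f = 1
    for m in range(n - k + 1, n + 1):
        b = a[:f, n - m + 1:]
        for i in range(1, m):
            a[i * f:(i + 1) * f, n - m] = i
            a[i * f:(i + 1) * f, n - m + 1:] = b + (b >= i)
        b += 1
        f *= m
    return a


# Exact comparison with itertools (same lexicographic order) for small N.
for N in range(1, 6):
    expected = np.array(list(permutations(range(10), N)))
    assert np.array_equal(permutations_(10, N), expected)

# Shape check across the full range 1 <= N <= 10: P(10, N) rows, N columns.
shapes = {N: permutations_(10, N).shape for N in range(1, 11)}
```

Note that the result is uint8; if the rest of the code expects the default integer dtype of np.array, an .astype call would be needed.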
Old answer
Using itertools.chain.from_iterable to flatten the tuples into a single iterator, and building the array lazily from it with np.fromiter, gives roughly a 2x improvement:
In [94]: from itertools import chain, permutations
In [95]: %timeit np.array(list(permutations(range(10), 6)), np.int8)
63.2 ms ± 500 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [96]: %timeit np.fromiter(chain.from_iterable(permutations(range(10), 6)), np.int8).reshape(-1, 6)
28.4 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
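np.fromiter also accepts a count argument; according to the NumPy documentation, specifying it lets fromiter pre-allocate the output array instead of resizing it on demand. A sketch of that variant, assuming the total number of values is P(10, N) * N; it has not been timed here, so any speed-up is untested:

```python
import math
from itertools import chain, permutations

import numpy as np


def perms_fromiter(n_digits: int, k: int) -> np.ndarray:
    # Total number of scalars: P(n_digits, k) rows times k columns.
    count = math.perm(n_digits, k) * k
    flat = np.fromiter(
        chain.from_iterable(permutations(range(n_digits), k)),
        dtype=np.int8,
        count=count,  # known length, so the output can be allocated up front
    )
    return flat.reshape(-1, k)


arr = perms_fromiter(10, 3)
```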
@KellyBundy proposed an even faster solution in the comments: build a bytes object from the flattened iterator (the bytes constructor iterates quickly) and wrap it with np.frombuffer through the buffer protocol. It seems that numpy.fromiter spends a lot of time on iteration:
In [98]: %timeit np.frombuffer(bytes(chain.from_iterable(permutations(range(10), 6))), np.int8).reshape(-1, 6)
11.3 ms ± 23.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
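The same one-liner wrapped as a function of N, as a sketch; perms_frombuffer is a hypothetical name. It relies on every digit 0-9 fitting in one byte, and np.frombuffer creates a view of the bytes object rather than copying it:

```python
from itertools import chain, permutations

import numpy as np


def perms_frombuffer(k: int) -> np.ndarray:
    # bytes() consumes the flattened iterator; each digit becomes one byte.
    raw = bytes(chain.from_iterable(permutations(range(10), k)))
    # View the bytes as int8 and restore one row per permutation.
    return np.frombuffer(raw, dtype=np.int8).reshape(-1, k)


arr = perms_frombuffer(3)
```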
Note, however, that the resulting array is read-only, because it is a view of an immutable bytes object (thanks to @MichaelSzczesny for the reminder):
In [109]: ar = np.frombuffer(bytes(chain.from_iterable(permutations(range(10), 6))), np.int8).reshape(-1, 6)
In [110]: ar[0, 0] = 1
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In [110], line 1
----> 1 ar[0, 0] = 1
ValueError: assignment destination is read-only
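If a writable array is needed, two options are to call .copy() on the result (one extra copy of the data), or to build a bytearray instead of bytes, since np.frombuffer returns a writable array when the underlying buffer is writable. A sketch of both; how the bytearray variant compares in speed with bytes has not been measured here:

```python
from itertools import chain, permutations

import numpy as np

# Option 1: bytearray exposes a writable buffer, so the view is writable too.
buf = bytearray(chain.from_iterable(permutations(range(10), 6)))
ar = np.frombuffer(buf, dtype=np.int8).reshape(-1, 6)
ar[0, 0] = 1  # no ValueError

# Option 2: copy the read-only view into a fresh, writable array.
ro = np.frombuffer(bytes(chain.from_iterable(permutations(range(10), 6))), np.int8)
ar2 = ro.reshape(-1, 6).copy()
ar2[0, 0] = 1
```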