Create all permutations of size N of the digits 0-9, as optimized as possible using numpy\scipy

Guy*_*ash 2 python numpy scipy

I need to create an array of all permutations of size N of the digits 0-9 (N is an input, 1 <= N <= 10).

I tried this:

    np.array(list(itertools.permutations(range(10), n)))

For n=6:

    timeit np.array(list(itertools.permutations(range(10), 6)))

On my machine this gives:

    68.5 ms ± 881 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

But that is not nearly fast enough. I need it to take less than 40 ms.

Note: I cannot change the NumPy version on the machine, which is 1.22.3.

Mec*_*Pig 7

Following the link provided by @KellyBundy, here is a fast method:

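    import math

    import numpy as np


    def permutations_(n, k):
        # Zero-length permutations: a single empty row.
        if k == 0:
            return np.empty((1, 0), np.uint8)

        shape = (math.perm(n, k), k)
        out = np.zeros(shape, np.uint8)
        # Seed the last column with the 1-permutations of 0..n-k.
        out[:n - k + 1, -1] = np.arange(n - k + 1, dtype=np.uint8)

        start = n - k + 1
        # Fill the remaining columns from right to left.
        for col in reversed(range(1, k)):
            # Rows built so far: all (k-col)-permutations of 0..n-col-1.
            block = out[:start, col:]
            length = start
            for i in range(1, n - col + 1):
                stop = start + length
                # Put i in front; shift entries >= i up by one so i is not repeated.
                out[start:stop, col:] = block + (block >= i)
                out[start:stop, col - 1] = i
                start = stop
            # Leading value 0 (already zero-filled): shift the block itself up by one.
            block += 1  # block is a sub-view on `out`

        return out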

A quick test:

    In [125]: %timeit permutations_(10, 6)
    3.73 ms ± 30.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

    In [128]: np.array_equal(permutations_(10, 6), np.array(list(permutations(range(10), 6))))
    Out[128]: True
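
To make the inner loop of the fast method less opaque, here is a minimal sketch of a single extension step on a tiny case. The helper name `extend_block` is hypothetical and not part of the answer. It assumes `block` already holds every r-permutation of 0..m-1 in lexicographic order, and it builds a new array instead of writing into a preallocated one, so it only illustrates the idea and is not tuned for speed.

```python
import itertools

import numpy as np


def extend_block(block, m):
    """Hypothetical helper: prepend one column to `block`.

    Assumes `block` holds every r-permutation of 0..m-1 in lexicographic
    order; returns every (r+1)-permutation of 0..m in the same order.
    """
    parts = []
    for first in range(m + 1):
        # Entries >= `first` are shifted up by one so `first` never repeats in a row.
        rest = block + (block >= first)
        lead = np.full((len(block), 1), first, dtype=block.dtype)
        parts.append(np.hstack([lead, rest]))
    return np.vstack(parts)


# All 1-permutations of 0..2, extended to all 2-permutations of 0..3.
block = np.arange(3, dtype=np.uint8).reshape(-1, 1)
pairs = extend_block(block, 3)
expected = np.array(list(itertools.permutations(range(4), 2)), dtype=np.uint8)
```

`pairs` and `expected` should compare equal with `np.array_equal`. The fast method applies the same shift, but writes each slice directly into `out` and handles the leading value 0 in place.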

Old answer

Using itertools.chain.from_iterable to concatenate the iterators of the tuples, so that the array is built lazily, gives some improvement:

    In [94]: from itertools import chain, permutations

    In [95]: %timeit np.array(list(permutations(range(10), 6)), np.int8)
    63.2 ms ± 500 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

    In [96]: %timeit np.fromiter(chain.from_iterable(permutations(range(10), 6)), np.int8).reshape(-1, 6)
    28.4 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
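
A possible variation on the `np.fromiter` approach, sketched here as an assumption rather than something that was timed in this thread: since the number of elements is known in advance, it can be passed through the `count` parameter so the output is allocated once. Whether this helps noticeably on a given machine would have to be measured.

```python
import math
from itertools import chain, permutations

import numpy as np

n, k = 10, 6
flat = chain.from_iterable(permutations(range(n), k))
# `count` tells fromiter the total number of items up front, so the output
# can be allocated once instead of being grown while iterating.
arr = np.fromiter(flat, dtype=np.int8, count=math.perm(n, k) * k).reshape(-1, k)
```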

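@KellyBundy suggested a faster solution in the comments, which uses the fast iteration in the bytes constructor together with the buffer protocol. It appears that numpy.fromiter wastes a lot of time on iteration: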

    In [98]: %timeit np.frombuffer(bytes(chain.from_iterable(permutations(range(10), 6))), np.int8).reshape(-1, 6)
    11.3 ms ± 23.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Note, however, that the result above is read-only (thanks to @MichaelSzczesny for the reminder):

    In [109]: ar = np.frombuffer(bytes(chain.from_iterable(permutations(range(10), 6))), np.int8).reshape(-1, 6)

    In [110]: ar[0, 0] = 1
    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    Cell In [110], line 1
    ----> 1 ar[0, 0] = 1

    ValueError: assignment destination is read-only
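
If a writable array is needed, two workarounds come to mind. This is a sketch that was not timed in this thread, so the extra cost of either option is an assumption to check: copy the read-only result, or build a `bytearray` instead of `bytes` so that the buffer NumPy wraps is itself mutable.

```python
from itertools import chain, permutations

import numpy as np

flat = bytes(chain.from_iterable(permutations(range(10), 6)))

# Option 1: copy the read-only view into a fresh, writable array.
writable_copy = np.frombuffer(flat, np.int8).reshape(-1, 6).copy()
writable_copy[0, 0] = 1

# Option 2: build a bytearray (a mutable buffer), so the view itself is writable.
buf = bytearray(chain.from_iterable(permutations(range(10), 6)))
writable_view = np.frombuffer(buf, np.int8).reshape(-1, 6)
writable_view[0, 0] = 1
```

With option 2 the array shares memory with `buf`, so changes made through one are visible through the other.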