可以将 numpy 数组添加到 python 集合中吗?

CBo*_*man 7 python casting tuples numpy set

我知道为了将一个元素添加到集合中,它必须是可散列的,而 numpy 数组似乎不是。这给我带来了一些问题,因为我有以下代码:

fill_set = set()
for i in list_of_np_1D:
    vecs = i + np_2D
    for j in range(N):
        tup = tuple(vecs[j,:])
        fill_set.add(tup)

# list_of_np_1D is a list of 1D numpy arrays
# np_2D is a 2D numpy array
# np_2D could also be converted to a list of 1D arrays if it helped.
Run Code Online (Sandbox Code Playgroud)

我需要让它运行得更快,将近 50% 的运行时间用于将 2D numpy 数组的切片转换为元组,以便将它们添加到集合中。

所以我一直在试图找出以下内容

  • 有没有什么方法可以使 numpy 数组或类似 numpy 数组(具有向量加法)功能的东西可哈希,以便可以将它们添加到集合中?
  • 如果没有,有没有办法可以加快元组转换的过程?

谢谢你的帮助!

HYR*_*YRY 3

首先创建一些数据:

import numpy as np
np.random.seed(1)
list_of_np_1D = np.random.randint(0, 5, size=(500, 6))
np_2D = np.random.randint(0, 5, size=(20, 6))
Run Code Online (Sandbox Code Playgroud)

运行你的代码:

%%time
fill_set = set()
for i in list_of_np_1D:
    vecs = i + np_2D
    for v in vecs:
        tup = tuple(v)
        fill_set.add(tup)
res1 = np.array(list(fill_set))
Run Code Online (Sandbox Code Playgroud)

输出:

CPU times: user 161 ms, sys: 2 ms, total: 163 ms
Wall time: 167 ms
Run Code Online (Sandbox Code Playgroud)

这是一个加速版本,它使用广播.view()方法将数据类型转换为字符串,调用后将set()字符串转换回数组:

%%time
r = list_of_np_1D[:, None, :] + np_2D[None, :, :]
stype = "S%d" % (r.itemsize * np_2D.shape[1])
fill_set2 = set(r.ravel().view(stype).tolist())
res2 = np.zeros(len(fill_set2), dtype=stype)
res2[:] = list(fill_set2)
res2 = res2.view(r.dtype).reshape(-1, np_2D.shape[1])
Run Code Online (Sandbox Code Playgroud)

输出:

CPU times: user 13 ms, sys: 1 ms, total: 14 ms
Wall time: 14.6 ms
Run Code Online (Sandbox Code Playgroud)

检查结果:

np.all(res1[np.lexsort(res1.T), :] == res2[np.lexsort(res2.T), :])
Run Code Online (Sandbox Code Playgroud)

您还可以使用lexsort()删除重复数据:

%%time
r = list_of_np_1D[:, None, :] + np_2D[None, :, :]
r = r.reshape(-1, r.shape[-1])

r = r[np.lexsort(r.T)]
idx = np.where(np.all(np.diff(r, axis=0) == 0, axis=1))[0] + 1
res3 = np.delete(r, idx, axis=0)
Run Code Online (Sandbox Code Playgroud)

输出:

CPU times: user 13 ms, sys: 3 ms, total: 16 ms
Wall time: 16.1 ms
Run Code Online (Sandbox Code Playgroud)

检查结果:

np.all(res1[np.lexsort(res1.T), :] == res3)
Run Code Online (Sandbox Code Playgroud)