从 Numpy 数组的索引中采样的有效方法?

Cup*_*tor 5 python random numpy sampling

我想从 2D Numpy 数组的索引中采样,考虑到每个索引都由该数组内的数字加权。numpy.random.choice然而,我知道它的方式不返回索引,而是返回数字本身。有什么有效的方法吗?

这是我的代码:

import numpy as np
A=np.arange(1,10).reshape(3,3)
A_flat=A.flatten()
d=np.random.choice(A_flat,size=10,p=A_flat/float(np.sum(A_flat)))
print d
Run Code Online (Sandbox Code Playgroud)

Bi *_*ico 2

你可以这样做:

import numpy as np

def wc(weights):
    cs = np.cumsum(weights)
    idx = cs.searchsorted(np.random.random() * cs[-1], 'right')
    return np.unravel_index(idx, weights.shape)
Run Code Online (Sandbox Code Playgroud)

请注意,累积和是其中最慢的部分,因此如果您需要对同一数组重复执行此操作,我建议提前计算累积和并重用它。

  • @Naji,“cs”已排序,“searchsorted()”利用它进行二分搜索 - 仅需要“O(log(len(weights)))”比较。非常便宜。 (2认同)