Speeding up multinomial random sampling in Python/NumPy

mvd*_*mvd 1 python optimization numpy vectorization scipy

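I'm generating a vector of draws from a multinomial distribution given a set of probabilities probs, where each draw is the index of the entry in probs that was chosen: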

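import numpy as np

def sample_mult(K, probs):
    # One slot per draw; K is the number of draws.
    result = np.zeros(K, dtype=np.int32)
    for n in range(K):
        # A single multinomial trial yields a one-hot vector over probs.
        draws = np.random.multinomial(1, probs)
        # The index of the 1 is the chosen entry.
        result[n] = np.where(draws == 1)[0][0]
    return result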

Can this be sped up? Calling np.random.multinomial over and over seems inefficient (and np.where might be slow too).

timeit: The slowest run took 6.72 times longer than the fastest. This could mean that an intermediate result is being cached. 100000 loops, best of 3: 18.9 µs per loop

Div*_*kar 6

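You can pass the size argument to np.random.multinomial to get one row of random samples per draw, instead of the single draw it returns by default, and then use .argmax(1) to emulate the np.where()[0][0] behaviour.

That gives a vectorized solution, like this: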

result = (np.random.multinomial(1,probs,size=K)==1).argmax(1)
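To make the one-liner easier to drop into the question's code, here is a minimal sketch that wraps it in a function. The names sample_mult_vectorized and sample_mult_choice are hypothetical, and the np.random.choice variant is not part of the answer above; it is just another common way to draw indices with given probabilities, shown for comparison.

```python
import numpy as np

def sample_mult_vectorized(K, probs):
    # A single call draws K one-hot rows, shape (K, len(probs)).
    one_hot = np.random.multinomial(1, probs, size=K)
    # Each row contains exactly one 1; argmax along axis 1 returns its index.
    return one_hot.argmax(axis=1)

def sample_mult_choice(K, probs):
    # Alternative (an assumption, not from the answer): sample the indices directly.
    return np.random.choice(len(probs), size=K, p=probs)

probs = np.array([0.2, 0.5, 0.3])
np.random.seed(0)  # fixed seed only so the example is repeatable
a = sample_mult_vectorized(10000, probs)
b = sample_mult_choice(10000, probs)

# Empirical frequencies of each index, for a rough sanity check against probs.
freq_a = np.bincount(a, minlength=len(probs)) / a.size
freq_b = np.bincount(b, minlength=len(probs)) / b.size
```

Both functions return a length-K integer array of indices into probs, so either should work as a replacement for the loop in sample_mult. Which one is faster likely depends on K and len(probs), so it is worth timing both on your own data.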