在numpy 页面上他们给出了示例
s = np.random.dirichlet((10, 5, 3), 20)
Run Code Online (Sandbox Code Playgroud)
这一切都很好、很棒;但是如果您想从 2D alpha 数组生成随机样本怎么办?
alphas = np.random.randint(10, size=(20, 3))
Run Code Online (Sandbox Code Playgroud)
如果您尝试np.random.dirichlet(alphas), np.random.dirichlet([x for x in alphas]), or np.random.dirichlet((x for x in alphas)),
则会产生
ValueError: object too deep for desired array. 唯一有效的似乎是:
y = np.empty(alphas.shape)
for i in xrange(np.alen(alphas)):
y[i] = np.random.dirichlet(alphas[i])
print y
Run Code Online (Sandbox Code Playgroud)
...这对于我的代码结构来说远非理想。为什么会出现这种情况,有人能想到一种更“类似 numpy”的方法吗?
提前致谢。
np.random.dirichlet被编写为生成单个狄利克雷分布的样本。该代码是根据伽玛分布实现的,并且该实现可以用作矢量化代码的基础,以从不同的分布生成样本。下面,采用形状为 (n, k) 的dirichlet_sample数组,其中每一行都是狄利克雷分布的向量。它返回一个形状为 (n, k) 的数组,每一行都是 中相应分布的样本。当作为脚本运行时,它使用和生成样本,以验证它们是否生成相同的样本(最多正常浮点差异)。alphasalphaalphasdirichlet_samplenp.random.dirichlet
import numpy as np
def dirichlet_sample(alphas):
"""
Generate samples from an array of alpha distributions.
"""
r = np.random.standard_gamma(alphas)
return r / r.sum(-1, keepdims=True)
if __name__ == "__main__":
alphas = 2 ** np.random.randint(0, 4, size=(6, 3))
np.random.seed(1234)
d1 = dirichlet_sample(alphas)
print "dirichlet_sample:"
print d1
np.random.seed(1234)
d2 = np.empty(alphas.shape)
for k in range(len(alphas)):
d2[k] = np.random.dirichlet(alphas[k])
print "np.random.dirichlet:"
print d2
# Compare d1 and d2:
err = np.abs(d1 - d2).max()
print "max difference:", err
Run Code Online (Sandbox Code Playgroud)
示例运行:
dirichlet_sample:
[[ 0.38980834 0.4043844 0.20580726]
[ 0.14076375 0.26906604 0.59017021]
[ 0.64223074 0.26099934 0.09676991]
[ 0.21880145 0.33775249 0.44344606]
[ 0.39879859 0.40984454 0.19135688]
[ 0.73976425 0.21467288 0.04556287]]
np.random.dirichlet:
[[ 0.38980834 0.4043844 0.20580726]
[ 0.14076375 0.26906604 0.59017021]
[ 0.64223074 0.26099934 0.09676991]
[ 0.21880145 0.33775249 0.44344606]
[ 0.39879859 0.40984454 0.19135688]
[ 0.73976425 0.21467288 0.04556287]]
max difference: 5.55111512313e-17
Run Code Online (Sandbox Code Playgroud)