Scipy Multivariate Normal:如何绘制确定性样本?

k88*_*074 3 python random statistics probability scipy

我使用Scipy.stats.multivariate_normal从多元正态分布中提取样本.像这样:

from scipy.stats import multivariate_normal
# Assume we have means and covs
mn = multivariate_normal(mean = means, cov = covs)
# Generate some samples
samples = mn.rvs()
Run Code Online (Sandbox Code Playgroud)

每次运行时样本都不同.如何获得相同的样品?我期待的是:

mn = multivariate_normal(mean = means, cov = covs, seed = aNumber)
Run Code Online (Sandbox Code Playgroud)

要么

samples = mn.rsv(seed = aNumber)
Run Code Online (Sandbox Code Playgroud)

War*_*ser 7

有两种方法:

  1. rvs()方法接受一个random_state参数.它的值可以是整数种子,也可以是实例numpy.random.RandomState.在这个例子中,我使用整数种子:

    In [46]: mn = multivariate_normal(mean=[0,0,0], cov=[1, 5, 25])
    
    In [47]: mn.rvs(size=5, random_state=12345)
    Out[47]: 
    array([[-0.51943872,  1.07094986, -1.0235383 ],
           [ 1.39340583,  4.39561899, -2.77865152],
           [ 0.76902257,  0.63000355,  0.46453938],
           [-1.29622111,  2.25214387,  6.23217368],
           [ 1.35291684,  0.51186476,  1.37495817]])
    
    In [48]: mn.rvs(size=5, random_state=12345)
    Out[48]: 
    array([[-0.51943872,  1.07094986, -1.0235383 ],
           [ 1.39340583,  4.39561899, -2.77865152],
           [ 0.76902257,  0.63000355,  0.46453938],
           [-1.29622111,  2.25214387,  6.23217368],
           [ 1.35291684,  0.51186476,  1.37495817]])
    
    Run Code Online (Sandbox Code Playgroud)
  2. 您可以为numpy的全局随机数生成器设置种子.这是multivariate_normal.rvs()使用if 的生成器random_state:

    In [54]: mn = multivariate_normal(mean=[0,0,0], cov=[1, 5, 25])
    
    In [55]: np.random.seed(123)
    
    In [56]: mn.rvs(size=5)
    Out[56]: 
    array([[  0.2829785 ,   2.23013222,  -5.42815302],
           [  1.65143654,  -1.2937895 ,  -7.53147357],
           [  1.26593626,  -0.95907779, -12.13339622],
           [ -0.09470897,  -1.51803558,  -4.33370201],
           [ -0.44398196,  -1.4286283 ,   7.45694813]])
    
    In [57]: np.random.seed(123)
    
    In [58]: mn.rvs(size=5)
    Out[58]: 
    array([[  0.2829785 ,   2.23013222,  -5.42815302],
           [  1.65143654,  -1.2937895 ,  -7.53147357],
           [  1.26593626,  -0.95907779, -12.13339622],
           [ -0.09470897,  -1.51803558,  -4.33370201],
           [ -0.44398196,  -1.4286283 ,   7.45694813]])
    
    Run Code Online (Sandbox Code Playgroud)