Thé*_*dez 5 python matplotlib histogram scikit-learn gmm
我想绘制一张叠加了一维高斯混合模型的直方图,就像图片中那样。
感谢 Meng 提供的图片。
我的直方图是这样的:
我有一个文件,其中一列包含大量数据(4,000,000 个数字):
1.727182
1.645300
1.619943
1.709263
1.614427
1.522313
Run Code Online (Sandbox Code Playgroud)
我正在使用以下脚本,其中采用了 Meng 和正义勋爵所做的修改:
from matplotlib import rc
from sklearn import mixture
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
import matplotlib.ticker as tkr
import scipy.stats as stats
# Read one sample per line; the context manager closes the file handle.
with open("prueba.dat") as data_file:
    x = data_file.read().splitlines()
# np.float was removed from NumPy (1.24); the builtin float is the same dtype.
f = np.ravel(x).astype(float)
f = f.reshape(-1, 1)  # single column, as passed to GaussianMixture.fit below
# Fit a three-component Gaussian mixture to the samples.
g = mixture.GaussianMixture(n_components=3, covariance_type='full')
g.fit(f)
weights = g.weights_
means = g.means_
covars = g.covariances_
# Normalised histogram of the raw samples.
plt.hist(f, bins=100, histtype='bar', density=True, ec='red', alpha=0.5)
# NOTE: only component 0 is drawn here, and f is passed in file order
# (unsorted), so consecutive points are joined in that order.
plt.plot(f, weights[0]*stats.norm.pdf(f, means[0], np.sqrt(covars[0])), c='red')
plt.rcParams['agg.path.chunksize'] = 10000
plt.grid()
plt.show()
Run Code Online (Sandbox Code Playgroud)
当我运行脚本时,得到了如下的图:
所以,我不知道如何确定所有应当存在的高斯分布的起点和终点。我是 Python 新手,对这些模块的用法感到困惑。请问您能帮助我,并指导我如何完成这幅图吗?
多谢
虽然这是一个相当老的帖子,但我想提供我的看法。我相信我的回答对一些人来说会更容易理解。此外,我还做了一项测试,通过 BIC 准则检查所选的分量数量是否具有统计意义。
# import libraries (some are for cosmetics)
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter, AutoMinorLocator)
import astropy
from scipy.stats import norm
from sklearn.mixture import GaussianMixture as GMM
import matplotlib as mpl
# Cosmetic matplotlib defaults for the two figures below.
mpl.rcParams['axes.linewidth'] = 1.5
mpl.rcParams.update({'font.size': 15, 'font.family': 'STIXGeneral', 'mathtext.fontset': 'stix'})

# create the data as in @Meng's answer
x = np.concatenate((np.random.normal(5, 5, 1000), np.random.normal(10, 2, 1000)))
x = x.reshape(-1, 1)  # single column, as passed to GMM.fit below

# first of all, let's confirm the optimal number of components:
# fit one model per candidate count (1 to 10) and record its BIC
n_candidates = np.arange(1, 11)
bics = [
    GMM(n_components=int(n), max_iter=1000, random_state=0,
        covariance_type='full').fit(x).bic(x)
    for n in n_candidates
]
# the lowest BIC wins; argmin returns the first minimum on ties
opt_bic = int(n_candidates[np.argmin(bics)])

# plot the evolution of BIC with the number of components
fig = plt.figure(figsize=(10, 4))
ax = fig.add_subplot(1, 2, 1)
# Plot 1
plt.plot(n_candidates, bics, 'o-', lw=3, c='black', label='BIC')
plt.legend(frameon=False, fontsize=15)
plt.xlabel('Number of components', fontsize=20)
plt.ylabel('Information criterion', fontsize=20)
plt.xticks(np.arange(0, 11, 2))
plt.title('Opt. components = '+str(opt_bic), fontsize=20)

# use the number of components selected by BIC
n_optimal = opt_bic

# create GMM model object and fit it once (each fit() call re-trains the model)
gmm = GMM(n_components=n_optimal, max_iter=1000, random_state=10, covariance_type='full')
gmm.fit(x)

# find useful parameters
mean = gmm.means_
covs = gmm.covariances_
weights = gmm.weights_

# create necessary things to plot: one weighted Gaussian per fitted component,
# so the plot also works when BIC selects a count other than two
x_axis = np.arange(-20, 30, 0.1)
components = [
    norm.pdf(x_axis, float(mean[k][0]), np.sqrt(float(covs[k][0][0]))) * weights[k]
    for k in range(n_optimal)
]
mixture_pdf = np.sum(components, axis=0)  # full mixture density

ax = fig.add_subplot(1, 2, 2)
# Plot 2
plt.hist(x, density=True, color='black', bins=np.arange(-100, 100, 1))
for k, component in enumerate(components):
    plt.plot(x_axis, component, lw=3, c='C' + str(k))
# the sum takes the next colour in the cycle ('C2' when there are two components)
plt.plot(x_axis, mixture_pdf, lw=3, c='C' + str(n_optimal), ls='dashed')
plt.xlim(-10, 20)
plt.xlabel(r"X", fontsize=20)
plt.ylabel(r"Density", fontsize=20)
plt.subplots_adjust(wspace=0.3)
plt.show()
plt.close('all')
Run Code Online (Sandbox Code Playgroud)
关键在于重塑(reshape)。首先,您需要对 f 进行重塑。对于 pdf,请在调用 stats.norm.pdf 之前先调整数组形状。同样,在绘图之前要先排序并调整形状。
from matplotlib import rc
from sklearn import mixture
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
import matplotlib.ticker as tkr
import scipy.stats as stats
# The samples could instead be read from "prueba.dat" (one value per line);
# synthetic data is generated here so the example is self-contained.
# create the data
x = np.concatenate((np.random.normal(5, 5, 1000), np.random.normal(10, 2, 1000)))
# np.float was removed from NumPy (1.24); the builtin float is the same dtype.
f = np.ravel(x).astype(float)
f = f.reshape(-1, 1)  # single column, as passed to GaussianMixture.fit below
g = mixture.GaussianMixture(n_components=3, covariance_type='full')
g.fit(f)
weights = g.weights_
means = g.means_
covars = g.covariances_
plt.hist(f, bins=100, histtype='bar', density=True, ec='red', alpha=0.5)
# Sorted 1-D copy of the samples: a line plot joins points in the order given,
# so the x values must be increasing to get a smooth curve.
f_axis = np.sort(f.ravel())
# Draw every fitted component, each scaled by its mixture weight. Scalars are
# taken from means/covars so the pdf comes back 1-D like f_axis.
for k in range(len(weights)):
    plt.plot(f_axis, weights[k]*stats.norm.pdf(f_axis, means[k, 0], np.sqrt(covars[k, 0, 0])), c='red')
plt.rcParams['agg.path.chunksize'] = 10000
plt.grid()
plt.show()
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
9147 次 |
| 最近记录: |