Jac*_*tta 10 python scikit-learn
I have a ".dat" file that stores X and Y values (so an array of shape (n, 2), where n is the number of rows).
import numpy as np
import matplotlib.pyplot as plt
import scipy.interpolate as interp
from sklearn import linear_model
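# path is the location of the two-column .dat file
text = np.loadtxt(path)  # loadtxt accepts a filename and returns an (n, 2) array
x = text[:, 0]           # X column, shape (n,): one-dimensional
y = text[:, 1]           # Y column, shape (n,)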
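A quick way to see where the trouble starts is to inspect the shapes. Taking a single column out of the (n, 2) array with an integer index drops that axis, so x has shape (n,), not (n, 1). The sketch below uses a small in-memory stand-in for the .dat file; the three rows are made up purely for illustration.

```python
import io

import numpy as np

# Hypothetical stand-in for the .dat file: two whitespace-separated columns (X, Y)
fake_dat = io.StringIO("1.0 3.1\n2.0 4.9\n3.0 7.2\n")

text = np.loadtxt(fake_dat)  # 2-D array, shape (n, 2)
x = text[:, 0]               # an integer column index drops that axis: shape (n,)
y = text[:, 1]

print(text.shape)  # (3, 2)
print(x.shape)     # (3,) -> 1-D, not the (n, 1) column LinearRegression wants for X
print(x.ndim)      # 1
```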
I created an instance of linear_model.LinearRegression(), but when I call its .fit(x, y) method I get:
IndexError: tuple index out of range
regr = linear_model.LinearRegression()
regr.fit(x,y)
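The IndexError comes from an old scikit-learn release (the traceback in the answer below is from a Python 2.7 install). In more recent releases, as far as I know, fit validates its input first and rejects a 1-D X with a ValueError whose message begins "Expected 2D array, got 1D array instead". A minimal reproduction with made-up data:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

x = np.array([1.0, 2.0, 3.0, 4.0])  # 1-D, shape (4,)
y = np.array([3.0, 5.0, 7.0, 9.0])

regr = LinearRegression()
try:
    regr.fit(x, y)  # X must be 2-D: (n_samples, n_features)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```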
What am I doing wrong?
Irs*_*hat 16
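LinearRegression expects X to be a two-dimensional array of shape (n_samples, n_features), and internally it uses X.shape[1] to initialize an np.ones array. For a 1-D x, X.shape is the one-element tuple (n,), so X.shape[1] raises "IndexError: tuple index out of range". Converting X into an n x 1 array fixes this. So replace: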
regr.fit(x,y)
with:
regr.fit(x[:,np.newaxis],y)
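Indexing with np.newaxis inserts a length-1 axis, turning shape (n,) into (n, 1): n samples with one feature each. Other standard NumPy spellings produce the same column. x.reshape(-1, 1) is the form recent scikit-learn error messages suggest, and indexing the original array with a list or a slice keeps the column axis from the start. The small array below is hypothetical data standing in for the loaded file.

```python
import numpy as np

# Hypothetical (n, 2) data standing in for the loaded .dat file
text = np.array([[1.0, 3.1], [2.0, 4.9], [3.0, 7.2]])
x = text[:, 0]  # shape (3,)

col_newaxis = x[:, np.newaxis]  # add a trailing axis: (3, 1)
col_reshape = x.reshape(-1, 1)  # -1 lets NumPy infer n: (3, 1)
col_list = text[:, [0]]         # a list index keeps the column axis: (3, 1)
col_slice = text[:, 0:1]        # a slice also keeps the axis: (3, 1)

for col in (col_newaxis, col_reshape, col_list, col_slice):
    print(col.shape)  # (3, 1) each time

print(np.array_equal(col_newaxis, col_list))  # True
```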
With that change the fit succeeds. Demo using the iris dataset:
>>> from sklearn import datasets
>>> from sklearn import linear_model
>>> clf = linear_model.LinearRegression()
>>> iris=datasets.load_iris()
>>> X=iris.data[:,3]
>>> Y=iris.target
>>> clf.fit(X,Y) # This will throw an error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/lib/python2.7/dist-packages/sklearn/linear_model/base.py", line 363, in fit
X, y, self.fit_intercept, self.normalize, self.copy_X)
File "/usr/lib/python2.7/dist-packages/sklearn/linear_model/base.py", line 103, in center_data
X_std = np.ones(X.shape[1])
IndexError: tuple index out of range
>>> clf.fit(X[:,np.newaxis],Y) # This will work properly
LinearRegression(copy_X=True, fit_intercept=True, normalize=False)
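Applied to the question's own two-column layout, the fix might look like the sketch below. It assumes the .dat file holds two numeric columns; a synthetic array with y = 2x + 1 exactly stands in for np.loadtxt(path), so the fitted slope and intercept are easy to check. Note that the LinearRegression repr in the demo comes from an old release: the normalize parameter was deprecated in scikit-learn 1.0 and removed in 1.2.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Synthetic stand-in for np.loadtxt(path): column 0 = X, column 1 = Y = 2*X + 1
x_vals = np.arange(10, dtype=float)
text = np.column_stack([x_vals, 2.0 * x_vals + 1.0])  # shape (10, 2)

X = text[:, 0].reshape(-1, 1)  # 2-D feature matrix, shape (n, 1)
y = text[:, 1]                 # 1-D target, shape (n,), is fine for y

regr = LinearRegression()
regr.fit(X, y)

print(regr.coef_)       # one coefficient per feature, approximately [2.]
print(regr.intercept_)  # approximately 1.0

# predict() needs the same 2-D layout as fit()
y_new = regr.predict(np.array([[20.0]]))
print(y_new)            # approximately [41.]
```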
To plot the regression line, use the following code:
>>> from matplotlib import pyplot as plt
>>> plt.scatter(X, Y, color='red')
<matplotlib.collections.PathCollection object at 0x7f76640e97d0>
>>> plt.plot(X, clf.predict(X[:,np.newaxis]), color='blue')
<matplotlib.lines.Line2D object at 0x7f7663f9eb90>
>>> plt.show()
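Because the model is a straight line, plotting predictions at the unsorted iris X values still looks correct, but it retraces the same segment back and forth. A slightly tidier sketch evaluates the line on a sorted grid instead. It assumes a headless environment, hence the Agg backend and savefig in place of plt.show(); the output filename is hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line to use plt.show()
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, linear_model

iris = datasets.load_iris()
X = iris.data[:, 3]  # petal width, 1-D
Y = iris.target

clf = linear_model.LinearRegression()
clf.fit(X[:, np.newaxis], Y)

# Evaluate the fitted line on a sorted grid spanning the data
x_line = np.linspace(X.min(), X.max(), 100)
y_line = clf.predict(x_line[:, np.newaxis])

fig, ax = plt.subplots()
ax.scatter(X, Y, color="red", label="data")
ax.plot(x_line, y_line, color="blue", label="fit")
ax.set_xlabel("petal width (cm)")
ax.set_ylabel("target class")
ax.legend()
fig.savefig("regression_line.png")  # hypothetical output filename
plt.close(fig)
```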