mot*_*iur 3 matlab regression machine-learning glm cross-validation
我正在使用广义线性模型进行回归,但使用该crossVal函数却措手不及。到目前为止,我的实现;
x = 'Some dataset, containing the input and the output'
X = x(:,1:7);
Y = x(:,8);
cvpart = cvpartition(Y,'holdout',0.3);
Xtrain = X(training(cvpart),:);
Ytrain = Y(training(cvpart),:);
Xtest = X(test(cvpart),:);
Ytest = Y(test(cvpart),:);
mdl = GeneralizedLinearModel.fit(Xtrain,Ytrain,'linear','distr','poisson');
Ypred = predict(mdl,Xtest);
res = (Ypred - Ytest);
RMSE_test = sqrt(mean(res.^2));
Run Code Online (Sandbox Code Playgroud)
下面的代码用于计算从该链接获得的多元回归的交叉验证。我想要类似的广义线性模型。
c = cvpartition(Y,'k',10);
regf=@(Xtrain,Ytrain,Xtest)(Xtest*regress(Ytrain,Xtrain));
cvMse = crossval('mse',X,Y,'predfun',regf)
Run Code Online (Sandbox Code Playgroud)
您可以手动执行交叉验证过程(为每个折叠训练模型,预测结果,计算错误,然后报告所有折叠的平均值),也可以使用CROSSVAL函数将整个过程包装在一个调用中。
举个例子,我将首先加载并准备一个数据集(“统计工具箱”随附的汽车数据集的子集):
% load regression dataset
load carsmall
X = [Acceleration Cylinders Displacement Horsepower Weight];
Y = MPG;
% remove instances with missing values
missIdx = isnan(Y) | any(isnan(X),2);
X(missIdx,:) = [];
Y(missIdx) = [];
clearvars -except X Y
Run Code Online (Sandbox Code Playgroud)
在这里,我们将使用cvpartition(未分层)使用k倍交叉验证对数据进行手动分区。对于每一折,我们使用训练数据训练GLM模型,然后使用该模型预测测试数据的输出。接下来,我们计算并存储此折叠的回归均方误差。最后,我们报告所有分区的平均RMSE。
% partition data into 10 folds
K = 10;
cv = cvpartition(numel(Y), 'kfold',K);
mse = zeros(K,1);
for k=1:K
% training/testing indices for this fold
trainIdx = cv.training(k);
testIdx = cv.test(k);
% train GLM model
mdl = GeneralizedLinearModel.fit(X(trainIdx,:), Y(trainIdx), ...
'linear', 'Distribution','poisson');
% predict regression output
Y_hat = predict(mdl, X(testIdx,:));
% compute mean squared error
mse(k) = mean((Y(testIdx) - Y_hat).^2);
end
% average RMSE across k-folds
avrg_rmse = mean(sqrt(mse))
Run Code Online (Sandbox Code Playgroud)
在这里,我们可以简单地使用适当的函数句柄调用CROSSVAL,该函数句柄在给定一组训练/测试实例的情况下计算回归输出。请参阅文档页面以了解参数。
% prediction function given training/testing instances
fcn = @(Xtr, Ytr, Xte) predict(...
GeneralizedLinearModel.fit(Xtr,Ytr,'linear','distr','poisson'), ...
Xte);
% perform cross-validation, and return average MSE across folds
mse = crossval('mse', X, Y, 'Predfun',fcn, 'kfold',10);
% compute root mean squared error
avrg_rmse = sqrt(mse)
Run Code Online (Sandbox Code Playgroud)
与之前相比,您应该获得相似的结果(由于交叉验证所涉及的随机性,当然会有一点不同)。
| 归档时间: |
|
| 查看次数: |
5449 次 |
| 最近记录: |