使用 mlr 预测计数

Rob*_*ong 2 r machine-learning mlr

我正在使用学习器regr.gbm来预测计数。在 之外mlrgbm直接使用包,我使用distribution = "poisson"predict.gbm,使用type = "response",返回原始比例的预测,但是我注意到,当我使用 执行此操作时mlr,预测似乎在对数比例上:

     truth    response
913      4  0.67348708
914      1  0.28413256
915      3  0.41871237
916      1  0.13027792
2101     1 -0.02092168
2102     2  0.23394970
Run Code Online (Sandbox Code Playgroud)

然而,“真相”不在对数尺度上,所以我担心超参数调整例程mlr不起作用。为了比较,这是我得到的输出distribution = "gaussian"

     truth response
913      4 2.028177
914      1 1.334658
915      3 1.552846
916      1 1.153072
2101     1 1.006362
2102     2 1.281811
Run Code Online (Sandbox Code Playgroud)

处理这个问题的最佳方法是什么?

小智 5

发生这种情况是因为gbm默认情况下对链接函数规模(log用于distribution = "poisson")进行预测。这是由type参数控制的gbm::predict.gbm(参见该函数的帮助页面)。不幸的是mlr,默认情况下不提供更改此参数(在 mlr 错误跟踪器中报告)。现在的解决方法是手动添加此参数:

lrn <- makeLearner("regr.gbm", distribution = "poisson")
lrn$par.set <- c(lrn$par.set,
  makeParamSet(
    makeDiscreteLearnerParam("type", c("link", "response"),
      default = "link", when = "predict", tunable = FALSE)))
lrn <- setHyperPars(lrn, type = "response")

# show that it works:
counttask <- makeRegrTask("counttask", getTaskData(pid.task),
  target = "pregnant")
pred <- predict(train(lrn, counttask), counttask)
pred
Run Code Online (Sandbox Code Playgroud)

请注意,在对计数数据调整参数时,默认回归度量(均方误差)可能会过分强调对具有大计数值的数据点的拟合。预测“10”而不是“1”的平方误差与预测“1010”而不是“1001”的误差相同,但根据您的目标,您可能希望在此示例中对第一个错误施加更多权重。

一个可能的解决方案是使用(标准化)平均泊松对数似然作为度量:

poisllmeasure = makeMeasure(
  id = "poissonllnorm",
  minimize = FALSE,
  best = 0,
  worst = -Inf,
  properties = "regr",
  name = "Mean Poisson Log Likelihood",
  note = "For count data. Normalized to 0 for perfect fit.",
  fun = function(task, model, pred, feats, extra.args) {
    mean(dpois(pred$data$truth, pred$data$response, log = TRUE) -
      dpois(pred$data$truth, pred$data$truth, log = TRUE))
})
# example
performance(pred, poisllmeasure)
Run Code Online (Sandbox Code Playgroud)

通过将其提供给 中的measures参数,此度量可用于调整tuneParams()。(请注意,您将不得不放弃它在列表中:tuneParams(... measures = list(poisllmeasure) ...)