重新定义 R 中 nnet::multinom 的 predict.multinom 预测方法,使其支持 type="link"

Tom*_*ers 3 r nnet multinomial emmeans marginal-effects

我希望新的 R 包 marginaleffects 能够支持 nnet::multinom 函数,但 marginaleffects::predictions() 依赖建模包提供的 predict() 方法来计算响应尺度和链接尺度上的预测值。然而,对于 nnet::multinom,nnet 提供的 predict() 方法不支持链接尺度的预测——它仅支持 type="probs" 和 type="class"(参见 https://github.com/vincentarelbundock/marginaleffects/issues/404 )。所以我想重新定义 nnet::multinom 的 predict.multinom 方法,使其也支持 type="link"(并且是在该包的原始命名空间中重新定义,以便 marginaleffects 包看到的也是重新定义后的版本)。有什么办法可以做到这一点吗?

作为参考,该predict.multinom方法(https://github.com/cran/nnet/blob/master/R/multinom.R)现在看起来像

predict.multinom <- function(object, newdata, type=c("class","probs"), ...)
{
    # Predict from a fitted "multinom" object.
    #
    # object:  a fit inheriting from class "multinom".
    # newdata: optional data to predict for; when omitted the stored fitted
    #          values are used.
    # type:    "class" returns a factor of predicted labels, "probs" returns
    #          the probability matrix (passed through drop()).
    if (!inherits(object, "multinom")) {
        stop("not a \"multinom\" fit")
    }
    type <- match.arg(type)

    if (missing(newdata)) {
        Y <- fitted(object)
    } else {
        newdata <- as.data.frame(newdata)
        all_rows <- row.names(newdata)
        rhs_terms <- delete.response(object$terms)
        # Rows with missing predictors are dropped here and stay NA in Y.
        mf <- model.frame(rhs_terms, newdata, na.action = na.omit,
                          xlev = object$xlevels)
        data_classes <- attr(rhs_terms, "dataClasses")
        if (!is.null(data_classes)) {
            .checkMFClasses(data_classes, mf)
        }
        design <- model.matrix(rhs_terms, mf, contrasts = object$contrasts)
        scores <- predict.nnet(object, design)
        Y <- matrix(NA, nrow(newdata), ncol(scores),
                    dimnames = list(all_rows, colnames(scores)))
        # Write predictions back into the positions of the retained rows.
        Y[match(row.names(mf), all_rows), ] <- scores
    }

    if (type == "class") {
        n_lev <- length(object$lev)
        if (n_lev > 2L) {
            # Several classes: take the column with the largest probability.
            Y <- factor(max.col(Y), levels = seq_along(object$lev),
                        labels = object$lev)
        } else if (n_lev == 2L) {
            # Two classes: a single probability thresholded at 0.5.
            Y <- factor(1 + (Y > 0.5), levels = 1L:2L, labels = object$lev)
        } else if (n_lev == 0L) {
            # No factor levels stored: label the columns with object$lab.
            Y <- factor(max.col(Y), levels = seq_along(object$lab),
                        labels = object$lab)
        }
    }
    drop(Y)
}
Run Code Online (Sandbox Code Playgroud)

predict.nnet(https://github.com/cran/nnet/blob/master/R/nnet.R )的定义如下:

# Predict from a fitted "nnet" object.
#
# object:  a fit inheriting from class "nnet".
# newdata: optional data to predict for; when omitted the stored fitted values
#          are used. For formula fits it is coerced to a data frame, otherwise
#          to a matrix (a dimensionless vector is treated as a single row).
# type:    "raw" returns the matrix of network outputs; "class" maps the
#          outputs to entries of object$lev.
# Errors if object is not an "nnet" fit, if a matrix-style newdata contains
# missing values, or if type = "class" is requested without object$lev.
predict.nnet <- function(object, newdata, type=c("raw","class"), ...)
{
    if(!inherits(object, "nnet")) stop("object not of class \"nnet\"")
    type <- match.arg(type)
    if(missing(newdata)) z <- fitted(object)
    else {
        if(inherits(object, "nnet.formula")) { #
            ## formula fit
            newdata <- as.data.frame(newdata)
            rn <- row.names(newdata)
            ## work hard to predict NA for rows with missing data
            Terms <- delete.response(object$terms)
            m <- model.frame(Terms, newdata, na.action = na.omit,
                             xlev = object$xlevels)
            if (!is.null(cl <- attr(Terms, "dataClasses")))
                .checkMFClasses(cl, m)
            # positions (within newdata) of the rows that survived na.omit
            keep <- match(row.names(m), rn)
            x <- model.matrix(Terms, m, contrasts = object$contrasts)
            xint <- match("(Intercept)", colnames(x), nomatch=0L)
            if(xint > 0L) x <- x[, -xint, drop=FALSE] # Bias term is used for intercepts
        } else {
            ## matrix ...  fit
            if(is.null(dim(newdata)))
                dim(newdata) <- c(1L, length(newdata)) # a row vector
            x <- as.matrix(newdata)     # to cope with dataframes
            if(any(is.na(x))) stop("missing values in 'x'")
            # no rows are dropped on this path, so every row is kept
            keep <- 1L:nrow(x)
            rn <- rownames(x)
        }
        ntr <- nrow(x)
        nout <- object$n[3L]
        # Hand the network description to the native code; a zero vector of
        # the same length as object$wts is passed in the weights position.
        # NOTE(review): presumably this call sets up native-side state that the
        # VR_nntest call below relies on -- confirm against the C sources.
        .C(VR_set_net,
           as.integer(object$n), as.integer(object$nconn),
           as.integer(object$conn), rep(0.0, length(object$wts)),
           as.integer(object$nsunits), as.integer(0L),
           as.integer(object$softmax), as.integer(object$censored))
        # Start from an all-NA result with one row per row of newdata, so rows
        # not listed in keep remain NA.
        z <- matrix(NA, nrow(newdata), nout,
                    dimnames = list(rn, dimnames(object$fitted.values)[[2L]]))
        # Run the native evaluation with the fitted weights; the returned
        # tclass buffer (length ntr*nout) is reshaped to ntr x nout.
        z[keep, ] <- matrix(.C(VR_nntest,
                               as.integer(ntr),
                               as.double(x),
                               tclass = double(ntr*nout),
                               as.double(object$wts)
                               )$tclass, ntr, nout)
        # NOTE(review): presumably releases the state set by VR_set_net -- confirm.
        .C(VR_unset_net)
    }
    switch(type, raw = z,
           class = {
               if(is.null(object$lev)) stop("inappropriate fit for class")
               # several outputs: pick the largest; one output: threshold at 0.5
               if(ncol(z) > 1L) object$lev[max.col(z)]
               else object$lev[1L + (z > 0.5)]
           })
}
Run Code Online (Sandbox Code Playgroud)

我希望可以用 predict.mblogit 函数(https://github.com/melff/mclogit/blob/master/pkg/R/mblogit.R )或与之接近的实现来覆盖 predict.multinom(由于 mblogit 和 nnet 对象的结构略有不同,可能需要做一些小的修改):

predict.mblogit <- function(object, newdata=NULL,type=c("link","response"),se.fit=FALSE,...){
  # Predict from a fitted "mblogit" object.
  #
  # object:  the fit; its terms, model, na.action, contrasts, xlevels and D
  #          components are used, together with coef() and (for se.fit) vcov().
  # newdata: optional data frame; when not supplied the stored model frame
  #          object$model is used.
  # type:    "link" returns the linear predictors with the first column
  #          removed; "response" returns the row-normalised exp(eta).
  # se.fit:  if TRUE, return list(fit = , se.fit = ) instead of a matrix.
  type <- match.arg(type)
  predictor_terms <- delete.response(terms(object))

  if (missing(newdata)) {
    frame <- object$model
    na_info <- object$na.action
  } else {
    frame <- model.frame(predictor_terms, data = newdata,
                         na.action = na.exclude)
    na_info <- attr(frame, "na.action")
  }

  X <- model.matrix(predictor_terms, frame,
                    contrasts.arg = object$contrasts,
                    xlev = object$xlevels)
  D <- object$D
  XD <- kronecker(X, D)

  # Reshape a long vector (nrow(D) consecutive values per observation) into
  # an observations x categories matrix.
  as_wide <- function(v) {
    wide <- matrix(v, ncol = nrow(D), byrow = TRUE)
    colnames(wide) <- rownames(D)
    wide
  }
  # Re-insert rows removed by the NA action, if there was one.
  restore <- function(v) {
    if (is.null(na_info)) v else napredict(na_info, v)
  }

  eta <- as_wide(c(XD %*% coef(object)))
  rownames(eta) <- rownames(X)

  if (se.fit) {
    V <- vcov(object)
    stopifnot(ncol(XD)==ncol(V))
  }

  if (type == "link") {
    link_fit <- eta[, -1, drop = FALSE]
    if (!se.fit) {
      return(restore(link_fit))
    }
    se_link <- as_wide(sqrt(rowSums(XD * (XD %*% V))))
    se_link <- se_link[, -1, drop = FALSE]
    return(list(fit = restore(link_fit), se.fit = restore(se_link)))
  }

  # type == "response"
  prob <- exp(eta)
  prob <- prob / rowSums(prob)
  if (!se.fit) {
    return(restore(prob))
  }

  prob_long <- as.vector(t(prob))
  obs_index <- rep(1:nrow(X), each = nrow(D))
  # Rows of XD centred by their probability-weighted sum within each
  # observation, then scaled by the probabilities.
  weighted <- prob_long *
    (XD - rowsum(prob_long * XD, obs_index)[obs_index, , drop = FALSE])
  se_prob <- as_wide(sqrt(rowSums(weighted * (weighted %*% V))))
  rownames(se_prob) <- rownames(prob)
  list(fit = restore(prob), se.fit = restore(se_prob))
}
Run Code Online (Sandbox Code Playgroud)

我想要实现的可重复示例:

# Data: SARS-CoV2 coronavirus variants (variant) through time
# (collection_date_num) in India; count = actual count (number of sequenced
# genomes).
dat <- read.csv("https://www.dropbox.com/s/u27cn44p5srievq/dat.csv?dl=1")
dat$collection_date <- as.Date(dat$collection_date)
# Numeric version of the date; to convert back to a date use
# as.Date(dat$collection_date_num, origin = "1970-01-01").
dat$collection_date_num <- as.numeric(dat$collection_date)
dat$variant <- factor(dat$variant)

# 1. nnet::multinom multinomial fit ####
library(nnet)
library(splines)
set.seed(1)
fit_nnet <- nnet::multinom(
  variant ~ ns(collection_date_num, df = 2),
  weights = count,
  data = dat
)
summary(fit_nnet)

# 2. Predicted probabilities & 95% CLs at the maximum date, calculated using
#    emmeans: works, but slow for large models ####
library(emmeans)
multinom_emmeans <- emmeans(
  fit_nnet,
  ~ variant,
  mode = "prob",
  at = list(collection_date_num = max(dat$collection_date_num))
)
multinom_emmeans
# variant               prob       SE df lower.CL upper.CL
# Alpha             0.00e+00 0.00e+00 33 0.00e+00 0.00e+00
# Beta              0.00e+00 0.00e+00 33 0.00e+00 0.00e+00
# Delta             7.73e-06 1.17e-06 33 5.34e-06 1.01e-05
# Omicron (BA.1)    1.82e-04 6.42e-05 33 5.14e-05 3.13e-04
# Omicron (BA.2)    1.76e-01 7.45e-03 33 1.61e-01 1.91e-01
# Omicron (BA.2.74) 9.03e-02 7.98e-03 33 7.41e-02 1.07e-01
# Omicron (BA.2.75) 1.68e-01 1.90e-02 33 1.30e-01 2.07e-01
# Omicron (BA.2.76) 2.89e-01 1.35e-02 33 2.62e-01 3.16e-01
# Omicron (BA.3)    1.34e-02 2.10e-03 33 9.10e-03 1.76e-02
# Omicron (BA.4)    1.67e-02 2.47e-03 33 1.17e-02 2.17e-02
# Omicron (BA.5)    2.03e-01 1.08e-02 33 1.81e-01 2.25e-01
# Other             4.23e-02 3.15e-03 33 3.59e-02 4.87e-02
#
# Confidence level used: 0.95


# 3. Predicted probabilities & 95% CLs at the maximum date, calculated using
#    marginaleffects: does not work because there is no predict.multinom
#    method supporting type = "link" ####

library(marginaleffects)
multinom_preds_marginaleffects <- predictions(
  fit_nnet,
  newdata = datagrid(collection_date_num = max(dat$collection_date_num)),
  type = "link", # not supported by predict.multinom
  transform_post = insight::link_inverse(fit_nnet)
)
# Error: The `type` argument for models of class `multinom` must be an element of: probs
# PS: desired output should match emmeans output above
Run Code Online (Sandbox Code Playgroud)

Hon*_*Ooi 5

重新定义包中某个方法的途径是使用 assignInNamespace。然而,如果这段代码最终要作为另一个公开发布的包的一部分,这样做就有些失礼,因为你是在覆盖别人的代码。特别是,如果您打算把它发布到 CRAN,可能很难说服 CRAN 审核者接受这种做法。

更好的解决方案是创建一个调用原始方法的包装方法。为此,您还需要为 multinom 创建一个包装函数,以便能找到正确的包命名空间。下面给出一个实现草图。

# Pass-through wrapper: forwards every argument unchanged to nnet::multinom()
# and returns its result.
multinom <- function(...) nnet::multinom(...)

# Link function for the multinomial model (generalized logit).
#
# mu: class probabilities, either a matrix with one row per observation and
#     one column per class, or a plain vector holding the class probabilities
#     of a single observation.
#
# Returns log(mu) centred so that the values of each observation sum to zero,
# in the same shape as the input. The author notes that these predictions are
# referred to as type = "latent" in the emmeans package.
inverse_softMax <- function(mu) {
  log_mu <- log(mu)
  if (is.null(dim(log_mu))) {
    # rowMeans()/sweep() need at least two dimensions; a dimensionless vector
    # is treated as one observation and centred directly.
    return(log_mu - mean(log_mu))
  }
  # Subtract each row's mean so the log-odds of every observation sum to zero.
  sweep(log_mu, 1, STATS = rowMeans(log_mu), FUN = "-")
}

# Wrapper around nnet's predict method for "multinom" fits that adds
# link-scale predictions.
#
# object:  a fitted model passed on to nnet:::predict.multinom().
# newdata: optional new data; passed on unchanged, so leaving it out still
#          yields predictions for the fitted data.
# type:    "probs" and "response" are synonyms and return probabilities;
#          "latent" and "link" are synonyms and return inverse_softMax() of
#          those probabilities.
# ...:     accepted for compatibility with the predict() generic; not used.
predict.multinom <- function(object, newdata,
                             type = c("probs", "response", "latent", "link"),
                             ...) {
  type <- match.arg(type)
  mu <- nnet:::predict.multinom(object, newdata, type = "probs")
  if (type == "probs" || type == "response") {
    return(mu)
  }

  if (is.null(dim(mu)) && length(object$lev) > 2L) {
    # With more than two classes the probabilities form a matrix with one
    # column per class; a dimensionless result therefore means a single row
    # was collapsed to a named vector. Restore the one-row matrix so the
    # link transform works row-wise.
    mu <- matrix(mu, nrow = 1L, dimnames = list(NULL, names(mu)))
  }
  inverse_softMax(mu)
}
Run Code Online (Sandbox Code Playgroud)