更快的 smsurv 功能

Rik*_*Rik 6 performance r

我正在尝试使 R 函数更加高效。下面找到一个工作示例。

    smsurv <- function(Time,Status,X,beta,w,model){    
    death_point <- sort(unique(subset(Time, Status==1)))
    if(model=='ph') coxexp <- exp((beta)%*%t(X[,-1]))  
    n <- length(death_point)
    lambda <- numeric(n)
    for(i in 1: n){
      if(model=='ph')  temp <- sum(as.numeric(Time>=death_point[i])*w*drop(coxexp))
      if(model=='aft')  temp <- sum(as.numeric(Time>=death_point[i])*w)
      lambda[i] <- sum(Status*as.numeric(Time==death_point[i]))/temp
    }
    HHazard <- numeric()
    for(i in 1:length(Time)){
      HHazard[i] <- sum(as.numeric(Time[i]>=death_point)*lambda)
      if(Time[i]>max(death_point))HHazard[i] <- Inf
      if(Time[i]<min(death_point))HHazard[i] <- 0
    }
    survival <- exp(-HHazard)
    list(survival=survival)
  }

nr_obs = 50000

Time_input <- rnorm(nr_obs, mean = 100, sd = 36)
Status_input <- sample(c(0,1), replace=TRUE, size=nr_obs)
w_input <- Status_input

# Let's suppose there are 9 variables (first column denotes the intercept)
n_variables <- 9
X_input <- matrix(rnorm(nr_obs*n_variables),nr_obs)
X_input <- cbind(Intercept = rep(1, nrow(X_input)), X_input) 

beta_input <- runif(n_variables, min = -1, max = 1)
model_input <- "ph"
output <- smsurv(Time_input,Status_input,X_input,beta_input,w_input,model_input)
Run Code Online (Sandbox Code Playgroud)

我已经尝试用 lapply 和 sapply 替换 for 循环,但这实际上使函数变得更慢:

    smsurv2 <- function(Time,Status,X,beta,w,model){    
    death_point <- sort(unique(subset(Time, Status==1)))
    if(model=='ph') coxexp <- exp((beta)%*%t(X[,-1]))  
    if(model_input=='ph') lambda =unlist(lapply(death_point, function(z) sum(Status_input*as.numeric(Time_input==z))/ sum(as.numeric(Time_input>=z)*w_input*drop(coxexp))))
    if(model=='aft') lambda =unlist( lapply(death_point, function(z) sum(Status_input*as.numeric(Time_input==z))/ sum(as.numeric(Time_input>=z)*w_input)))
    HHazard <- unlist(lapply(Time, function(t) {sum(as.numeric(t>=death_point)*lambda)}))
    HHazard[Time > max(death_point)] <- Inf
    HHazard[Time < min(death_point)] <- 0

    survival <- exp(-HHazard)
    list(survival=survival)
  }

smsurv3 <- function(Time, Status, X, beta, w, model){
  death_point <- sort(unique(subset(Time, Status==1)))
  if(model=='ph') coxexp <- exp((beta)%*%t(X[,-1]))
  lambda <- sapply(death_point, function(dp) {return(sum(Status*as.numeric(Time==dp))/sum(as.numeric(Time>=dp)*w*drop(coxexp)))})
  HHazard <- sapply(Time, function(t){return(sum(as.numeric(t>=death_point)*lambda))})
  HHazard[Time > max(death_point)] <- Inf
  HHazard[Time < min(death_point)] <- 0

  survival <- exp(-HHazard)
  list(survival=survival)
}
Run Code Online (Sandbox Code Playgroud)

考虑到这一点,有人还有其他我可以尝试的建议吗?我最近阅读了有关 rcpp 包的信息,但我不确定如何用 C 代码替换 for 循环。任何建议都非常受欢迎。

Rui*_*das 13

下面的函数smsurv2速度更快,结果与问题的函数相同。

以下是我所做的一些更改。

  • subset慢,索引子集更快;
  • coxexp我从第一个循环中删除了 的计算,for使循环代码更简单,它总是乘以coxexp,它可以是 1 的向量,从而删除测试if
  • seq_along1:n比;更安全
  • 从第二个循环中删除了对 和 0 的赋值Inf,这可以在循环外进行向量化;
  • as.numeric永远不需要,因此消除了几个函数调用。

smsurv2 <- function(Time, Status, X, beta, w, model){
  death_point <- Time[Status == 1] |> unique() |> sort()
  n <- length(death_point)
  lambda <- numeric(n)
  if(model == 'ph') {
    coxexp <- (exp(beta %*% t(X[, -1]))) |> drop()
  } else if(model == 'aft') {
    coxexp <- rep(1, length(Time))
  }
  for(i in seq_along(death_point)){
    temp <- sum((Time >= death_point[i]) * w * coxexp)
    lambda[i] <- sum(Status * (Time == death_point[i])) / temp
  }
  HHazard <- numeric(length(Time))
  for(i in seq_along(Time)){
    HHazard[i] <- sum((Time[i] >= death_point) * lambda)
  }
  HHazard[ Time > max(death_point) ] <- Inf
  HHazard[ Time < min(death_point) ] <- 0
  survival <- exp(-HHazard)
  list(survival = survival)
}


nr_obs = 50000

Time_input <- rnorm(nr_obs, mean = 100, sd = 36)
Status_input <- sample(c(0,1), replace=TRUE, size=nr_obs)
w_input <- Status_input

# Let's suppose there are 9 variables (first column denotes the intercept)
n_variables <- 9
X_input <- matrix(rnorm(nr_obs*n_variables),nr_obs)
X_input <- cbind(Intercept = rep(1, nrow(X_input)), X_input) 

beta_input <- runif(n_variables, min = -1, max = 1)
model_input <- "ph"

system.time(
  output <- smsurv(Time_input,Status_input,X_input,beta_input,w_input,model_input)
)
#>    user  system elapsed 
#>   25.08    4.95   33.39

system.time(
  output2 <- smsurv2(Time_input,Status_input,X_input,beta_input,w_input,model_input)
)
#>    user  system elapsed 
#>   16.92    1.45   19.86

identical(output, output2)
#> [1] TRUE
Run Code Online (Sandbox Code Playgroud)

创建于 2023-09-30,使用reprex v2.0.2

  • 这个答案已从元链接:https://meta.stackoverflow.com/q/426696。可能会解释为什么您对一个非常好的答案投了反对票。 (8认同)

jbl*_*d94 7

矢量化可实现 >600 倍的加速

cumsum这里的要点是在排序后使用Time以避免循环。然后一切都被矢量化,并且几乎立即运行。

由于这个问题得到了如此多的关注,我添加了一个更长的解释来解释为什么for这里的循环如此缓慢。滚动到底部即可。

使用data.table

library(data.table)

smsurv2 <- function(Time, Status, X, beta, w, model){
  list(
    survival = setorder(
      setorder(
        data.table(Time, Status)[,r := .I],
        -Time, Status
      )[
        ,`:=`(
          death_point = Time != shift(Time, -1L, Time[1] - 1) & Status,
          temp = cumsum((if (model=='ph') w*drop(exp(tcrossprod(beta, X[,-1]))) else w)[r]),
          lambda = cumsum(Status)
        )
      ][
        death_point == TRUE, lambda := c(lambda[1], diff(lambda))/temp
      ][
        death_point == FALSE, lambda := 0
      ][
        ,survival := exp(-rev(cumsum(rev(lambda))))
      ][
        seq_len(which.max(death_point) - 1L), survival := 0
      ][, c(1:2, 4:6) := NULL],
      r
    )[[2]]
  )
}
Run Code Online (Sandbox Code Playgroud)

使用 OP 的示例数据比较时间:

system.time({
  output <- smsurv(Time_input,Status_input,X_input,beta_input,w_input,model_input)
})
#>    user  system elapsed 
#>   36.19    4.97   42.89
system.time({
  output2 <- smsurv2(Time_input,Status_input,X_input,beta_input,w_input,model_input)
})
#>    user  system elapsed 
#>    0.07    0.00    0.07

all.equal(output, output2)
#> [1] TRUE
Run Code Online (Sandbox Code Playgroud)

避免for循环

考虑第一个循环中的第一行for

if(model=='ph')  temp <- sum(as.numeric(Time>=death_point[i])*w*drop(coxexp))
Run Code Online (Sandbox Code Playgroud)

这里,death_point是取自 的值向量Timew*coxexp是与 长度相同 (50000) 的值向量Time。该行只是简单地求和w*coxexp,其中对应的值Time大于或等于death_point[i]。这涉及 50000 次比较,然后进行两次(向量化)乘法,并对结果求和。如果death_point长度为 1,那么这样做是有意义的。然而,death_point在例子中长度是24920,所以这一行涉及50000*24920~12.5亿次比较!相反,我们实际上可以通过一次排序来完成所有这些比较:将Time和 的结果w*coxexp放入表中,按 降序排序Time,计算 的累积和w*coxexp,然后获取标记为death_points 的时间的值。我们现在拥有了所有需要的值temp,并且只需要两个向量化函数调用(setordercumsum)。所有其他计算都smsurv2可以轻松执行,因为表已经排序。