`dplyr::between` on millions of intervals: how to make it faster

Mat*_*ers 4 performance for-loop r between dataframe

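I have a data frame containing a startValue (in column X1) and an endValue (in column X2), plus a list of outliers. I need to count how many outliers fall within each [startValue, endValue] interval. Here is a minimal example and my attempt: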

library("dplyr")

set.seed(2022)

dat = data.frame(matrix(sort(sample(1:500, 50)), ncol = 2, byrow = T))
dim(dat)
dat$Outliers = NA

outliers = sort(sample(1:500, 30))

for (i in 1:25){
  dat$Outliers[i] = length(which(between(outliers, dat$X1[i], dat$X2[i]) == T))
}

This code works fine. But my real data has millions of rows, and this for loop takes a long time. Is there a faster way to solve this?

李哲源*_*李哲源 5

A shallow fix at the coding level

With the following modification, you will see an immediate speedup:

out <- integer(nrow(dat))
for (i in 1:nrow(dat)) {
  out[i] <- sum(between(outliers, dat$X1[i], dat$X2[i]))
}
dat$Outliers1 <- out
with(dat, identical(Outliers, Outliers1))  ## the result is correct
#[1] TRUE
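  1. This avoids modifying the data.frame inside the loop, and therefore avoids a lot of unnecessary memory copies (this matters especially if you have a large data.frame).

  2. length(which(LogicalVec == TRUE)) is awkward; use sum(LogicalVec) instead.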
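As a small illustration of point 2, here is a minimal sketch on a made-up logical vector (the variable names are only for this example). It also shows the one case where the two forms differ: if the logical vector contains NA, which() drops it while sum() needs na.rm = TRUE.

```r
# Toy logical vector standing in for the result of between().
in_range <- c(TRUE, FALSE, TRUE, TRUE)
n_which <- length(which(in_range == TRUE))  # builds an index vector, then measures it
n_sum   <- sum(in_range)                    # TRUE counts as 1, FALSE as 0

# The two forms differ only when NA is present:
# which() silently drops NA, sum() propagates it unless na.rm = TRUE.
with_na <- c(TRUE, NA, FALSE)
n_which_na <- length(which(with_na))
n_sum_na   <- sum(with_na, na.rm = TRUE)
```

If your outliers may contain missing values, decide explicitly how they should be treated before switching to sum().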

You can also try replacing the loop with:

dat$Outliers2 <- mapply(\(a, b) sum(between(outliers, a, b)), dat$X1, dat$X2)
with(dat, identical(Outliers, Outliers2))  ## the result is correct
#[1] TRUE

Henrik mentioned a solution based on data.table:

library(data.table)
setDT(dat)
d_vals <- as.data.table(outliers)
dat[ , Outliers3 := d_vals[dat, on = .(outliers >= X1, outliers <= X2), .N, by = .EACHI]$N]
with(dat, identical(Outliers, Outliers3))
#[1] TRUE

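A deep fix at the algorithm level

If your real data resemble your simulated data, with non-overlapping [startValue, endValue] intervals, you can use base::findInterval instead of dplyr::between. The only catch is that findInterval works with half-open intervals [a, b), whereas between uses closed intervals [a, b]. So if an outlier happens to equal an endValue, it is counted by between but not by findInterval. This does happen in your minimal example.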

vec <- c(with(dat, rbind(X1, X2)))
if (is.unsorted(vec, strictly = TRUE)) stop("cannot apply this method!")
x <- findInterval(outliers, vec, rightmost.closed = TRUE)
x <- x[x %% 2 == 1]  ## get rid of X2[i] ~ X1[i + 1]; just keep X1[i] ~ X2[i]
x <- (x + 1) / 2
dat$Outliers4 <- tabulate(x, nrow(dat))
with(dat, identical(Outliers, Outliers4))  ## oops, the results are not identical
#[1] FALSE

## you see, the first interval is causing problem
with(dat, Outliers - Outliers4)
#[1] 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

## inspect the first interval
dat[1, 1:2]
#  X1 X2
#1  1  3

## inspect outliers: 3 is an outlier value
## it is not in [1, 3), so not counted by `findInterval`
## but it is in [1, 3], so counted by `between`
outliers
# [1]   1   3  19  32  42  56  87  99 120 218 234 248 276 287 306 316 319 336 351
#[20] 353 361 382 389 400 417 423 425 458 467 480
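To see why the code above keeps only the odd indices, here is a minimal sketch on two made-up disjoint intervals (the values and the names pts, idx, keep and counts are assumptions for illustration only). After interleaving X1 and X2, odd bins are the real intervals and even bins are the gaps between them.

```r
# Two disjoint toy intervals, [1, 3] and [6, 9] (made-up values).
X1 <- c(1, 6)
X2 <- c(3, 9)
vec <- c(rbind(X1, X2))   # interleaved break points: 1 3 6 9

pts <- c(0, 1, 2, 3, 4, 6, 9, 10)
idx <- findInterval(pts, vec, rightmost.closed = TRUE)
# idx is 0 1 1 2 2 3 3 4:
#   odd values (1, 3)       = inside an interval X1[i] ~ X2[i]
#   even values (2, 4) or 0 = in a gap, or outside the whole range

keep <- idx %% 2 == 1
counts <- tabulate((idx[keep] + 1) / 2, length(X1))
# counts is 2 2: the point 3 (equal to X2[1]) falls into the gap [3, 6),
# while the point 9 is kept only because of rightmost.closed = TRUE.
```

Note that rightmost.closed = TRUE closes only the very last interval, so the last row already includes its endValue while every other row does not.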

Well, depending on what you ultimately want, you may or may not need to fix this discrepancy. If you do want to fix it, you can do:

i <- which(dat$X2 %in% outliers)
dat$Outliers4[i] <- dat$Outliers4[i] + 1L
with(dat, identical(Outliers, Outliers4))
#[1] TRUE

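Wait, this assumes that outliers has no duplicated values. If a value appears more than once, you need to add its number of occurrences rather than 1. You need:

i <- which(dat$X2 %in% outliers)
## count how many times each X2[i] occurs in outliers
## (match() gives NA for values that hit no endpoint; tabulate() ignores NA)
m <- tabulate(match(outliers, dat$X2[i]), length(i))
dat$Outliers4[i] <- dat$Outliers4[i] + m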
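As an aside, the endpoint bookkeeping can be avoided by turning the problem around. This is a different technique from the findInterval + tabulate approach in this answer, offered only as a sketch: sort the outliers once, then for each interval subtract the number of outliers strictly below startValue from the number at or below endValue. The helper name count_closed and the toy data are assumptions; the left.open argument of findInterval is what provides the strict "<" comparison.

```r
# Count the values of `outliers` inside each closed interval [lo, hi].
# `count_closed` is a hypothetical helper name, not part of any package.
count_closed <- function(lo, hi, outliers) {
  s <- sort(outliers)
  n_le_hi <- findInterval(hi, s)                    # number of s <= hi
  n_lt_lo <- findInterval(lo, s, left.open = TRUE)  # number of s <  lo
  n_le_hi - n_lt_lo
}

# Toy data: overlapping intervals and a duplicated outlier value (3).
lo <- c(1, 2, 10)
hi <- c(3, 5, 12)
outliers <- c(3, 1, 3, 5, 7)
res <- count_closed(lo, hi, outliers)   # 3 3 0
```

Because each interval is handled independently, this sketch does not require the intervals to be sorted or non-overlapping. It was not part of the benchmark below, so no timing is claimed for it.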

Benchmark

## a 1000 x 2 data.frame
set.seed(2022)
dat = data.frame(matrix(sort(sample(20000, 2000)), ncol = 2, byrow = T))
outliers = sort(sample(20000, 1000))
## your current method
method0 <- function (dat, outliers) {
  dat$Outliers = NA
  for (i in 1:nrow(dat)) {
    dat$Outliers[i] = length(which(between(outliers, dat$X1[i], dat$X2[i]) == T))
  }
  dat
}
## modified `for`-loop method
method1 <- function (dat, outliers) {
  out <- integer(nrow(dat))
  for (i in 1:nrow(dat)) {
    out[i] <- sum(between(outliers, dat$X1[i], dat$X2[i]))
  }
  dat$Outliers <- out
  dat
}

## `mapply` method
method2 <- function (dat, outliers) {
  dat$Outliers <- mapply(\(a, b) sum(between(outliers, a, b)), dat$X1, dat$X2)
  dat
}
## "data.table" method
method3 <- function (dat, outliers) {
  DT <- as.data.table(dat)
  d_vals <- as.data.table(outliers)
  DT[ , Outliers := d_vals[DT, on = .(outliers >= X1, outliers <= X2), .N, by = .EACHI]$N]
  DT
}
## findInterval() + tabulate(), with optional discrepancy fix
method4 <- function (dat, outliers, fix = FALSE) {
  vec <- c(with(dat, rbind(X1, X2)))
  if (is.unsorted(vec, strictly = TRUE)) stop("cannot apply this method!")
  x <- findInterval(outliers, vec, rightmost.closed = TRUE)
  x <- x[x %% 2 == 1]
  x <- (x + 1) / 2
  out <- tabulate(x, nrow(dat))
  if (fix) {
    i <- which(dat$X2 %in% outliers)
    if (anyDuplicated(outliers)) {
      m <- tabulate(match(outliers, dat$X2[i]), length(i))  ## occurrences of each X2[i] in outliers
    } else {
      m <- 1L
    }
    out[i] <- out[i] + m
  }
  dat$Outliers <- out
  dat
}
library(microbenchmark)
microbenchmark(method0 = method0(dat, outliers),
               method1 = method1(dat, outliers),
               method2 = method2(dat, outliers),
               method3 = method3(dat, outliers),
               method4F = method4(dat, outliers, fix = FALSE),
               method4T = method4(dat, outliers, fix = TRUE))
#Unit: microseconds
#     expr       min         lq       mean    median        uq       max neval
#  method0 40459.251 40854.4165 44121.8656 41015.071 46436.530 171540.17   100
#  method1 20755.578 20905.8405 21740.9797 21032.032 21309.445  30860.18   100
#  method2 19320.867 19483.9885 20452.9461 19585.355 19808.503  26417.01   100
#  method3  4920.958  5171.5260  5411.6465  5340.284  5379.681  12849.05   100
# method4F   160.382   180.7245   260.6430   217.843   228.279   5548.48   100
# method4T   331.903   353.5260   612.7073   388.385   410.383  23536.62   100

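So the data.table method is the fastest coding-level solution, roughly 7 to 8 times faster than the original loop. But if the intervals do not overlap, the improved algorithm gives a speedup of more than 100 times! The discrepancy fix is not free, but it is still fast.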


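Hint

If you do have overlapping intervals, the findInterval + tabulate method can still be applied. The key is to work with non-overlapping intervals first, then aggregate the counts. Here is a conceptual illustration. Suppose we have two overlapping intervals: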

[1, 3], [2, 5]

We can first split them into non-overlapping ones:

[1, 2), [2, 3), [3, 5]

Now findInterval + tabulate can find the counts for these intervals:

[1, 2)  [2, 3)  [3, 5]
N1      N2      N3

The desired result is an aggregation:

[1, 3)  [2, 5]
N1+N2   N2+N3
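The splitting and aggregating idea above can be sketched as follows. This is only one possible way to do it, under the assumption that a cumulative sum is used for the aggregation step; the outlier values and the names breaks, bin, n, cs, lo, hi and agg are made up for illustration.

```r
# The two overlapping intervals from the diagram, and some made-up outliers.
X1 <- c(1, 2)
X2 <- c(3, 5)
outliers <- c(1, 2, 2, 3, 4, 5, 6)

# All endpoints, sorted and de-duplicated, define the non-overlapping pieces.
breaks <- sort(unique(c(X1, X2)))   # 1 2 3 5
k <- length(breaks)

# Counts N1, N2, N3 for [1, 2), [2, 3), [3, 5]; values outside [1, 5] are dropped.
bin <- findInterval(outliers, breaks, rightmost.closed = TRUE)
n <- tabulate(bin, k - 1)           # 1 2 3

# Aggregate with a cumulative sum: interval i covers pieces lo[i] .. hi[i] - 1.
cs <- c(0L, cumsum(n))
lo <- match(X1, breaks)
hi <- match(X2, breaks)
agg <- cs[hi] - cs[lo]              # N1+N2 = 3, N2+N3 = 5
```

As in the diagram, the first interval is counted as [1, 3), so the endpoint discrepancy discussed earlier still applies to every interval that does not end at the largest break point.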

Remember: nothing beats a clever algorithm.