Was*_*abi 4 r data.table rolling-computation
我有一个包含两列a和的 data.frame b,在哪里a排序。我想要得到的滚动平均b,其中窗口的范围a - 5,以a(即从当前值a到哪里a - 5是)。
使用不同的窗口宽度执行滚动平均值是微不足道的data.table::frollmean()(adaptive = TRUE;“每个单独的观察都有自己对应的滚动窗口宽度”),所以唯一的问题是计算这些窗口宽度。
那么,给定以下 data.frame,如何确定每个均值的窗口大小?
set.seed(42)
x <- data.frame(
a = sort(runif(10, 0, 10)),
b = 1:10
)
x
#> a b
#> 1 1.346666 1
#> 2 2.861395 2
#> 3 5.190959 3
#> 4 6.417455 4
#> 5 6.569923 5
#> 6 7.050648 6
#> 7 7.365883 7
#> 8 8.304476 8
#> 9 9.148060 9
#> 10 9.370754 10
Run Code Online (Sandbox Code Playgroud)
由reprex 包(v0.3.0)于 2020 年 7 月 3 日创建
如果我将窗口大小作为一个新列n,我希望结果是
#> a b n
#> 1 1.346666 1 1
#> 2 2.861395 2 2
#> 3 5.190959 3 3
#> 4 6.417455 4 3
#> 5 6.569923 5 4
#> 6 7.050648 6 5
#> 7 7.365883 7 6
#> 8 8.304476 8 6
#> 9 9.148060 9 7
#> 10 9.370754 10 8
Run Code Online (Sandbox Code Playgroud)
因此,举例来说,有之间的两个值a[2] = 2.86和2.86 - 5(包括其本身),并有之间六个值a[8] = 8.30和8.30 - 5。
我已经设法做到这一点outer:
suppressPackageStartupMessages({
library(magrittr)
library(data.table)
})
f <- function(x, y) {
return(y %between% list(x - 5, x))
}
outer(x$a, x$a, f) %>% rowSums()
#> [1] 1 2 3 3 4 5 6 6 7 8
Run Code Online (Sandbox Code Playgroud)
然而,我的真实案例很容易有 5000 行,而且这种方法变得很慢(大约需要 10 秒)。我看到的一个问题是它将 的每个值a与 的每个其他值a进行比较,因此必须执行大约 25,000,000 次比较。但是,我知道a是排序的,所以如果我们TRUE在比较中找到一段结果然后 a FALSE,我们知道当前值的所有后续结果a也将是FALSE(这意味着我们在允许的范围内,然后移过了的最高允许值a,因此其他所有内容也将被拒绝)。
那么,有没有更好、更快的方法来做到这一点?
这是一种以非等价自连接聚合的替代方法:
library(data.table)
setDT(x)[, low := a - 5][
, n := x[x, on = .(a >= low , a <= a), by = .EACHI, .N]$N][
, low := NULL][]
Run Code Online (Sandbox Code Playgroud)
Run Code Online (Sandbox Code Playgroud)a b n 1: 1.346666 1 1 2: 2.861395 2 2 3: 5.190959 3 3 4: 6.417455 4 3 5: 6.569923 5 4 6: 7.050648 6 5 7: 7.365883 7 6 8: 8.304476 8 6 9: 9.148060 9 7 10: 9.370754 10 8
但OP的目标是计算具有可变窗口大小的滚动平均值。
frollmean()那么,当我们可以一次性得到它时,为什么要停下来打电话呢?:
library(data.table)
setDT(x)[, low := a - 5][
, roll.mean := x[x, on = .(a >= low , a <= a), by = .EACHI, mean(b)]$V1][
, low := NULL][]
Run Code Online (Sandbox Code Playgroud)
Run Code Online (Sandbox Code Playgroud)a b roll.mean 1: 1.346666 1 1.0 2: 2.861395 2 1.5 3: 5.190959 3 2.0 4: 6.417455 4 3.0 5: 6.569923 5 3.5 6: 7.050648 6 4.0 7: 7.365883 7 4.5 8: 8.304476 8 5.5 9: 9.148060 9 6.0 10: 9.370754 10 6.5
由于 OP 关心其生产用例的性能,这里有一个基准,它会改变行数以及窗口大小:
library(bench)
library(ggplot2)
bm <- press(
n = 10^(c(2, 3, 4)),
window_size = c(5, 15, 50),
{
set.seed(42)
x0 <- data.table(
a = sort(runif(n, 0, n)),
b = seq(n)
)
mark(
findInterval = {
x <- copy(x0)
x[, roll.mean := frollmean(b, .I - findInterval(a - window_size, a), adaptive = TRUE)]
},
non_equi_join = {
x <- copy(x0)
x[, low := a - window_size][
, roll.mean := x[x, on = .(a >= low , a <= a), by = .EACHI, mean(b)]$V1][
, low := NULL]
}
)
}
)
autoplot(bm)
Run Code Online (Sandbox Code Playgroud)
显然,
findInterval()方法与自适应方法的组合frollmean()总是比非等值连接方法快一个数量级以上因为您似乎data.table无论如何都会加载(对于frollmean),您可以强制您data.frame到data.table,并通过引用添加新列。
findInterval用于在原始值中找到每个减去值的索引。然后从原始索引中减去该索引,通过.I或获得seq_along,以获得窗口大小。
setDT(x)
x[ , n := .I - findInterval(a - 5, a)]
# x
# a b n
# 1: 1.346666 1 1
# 2: 2.861395 2 2
# 3: 5.190959 3 3
# 4: 6.417455 4 3
# 5: 6.569923 5 4
# 6: 7.050648 6 5
# 7: 7.365883 7 6
# 8: 8.304476 8 6
# 9: 9.148060 9 7
# 10: 9.370754 10 8
Run Code Online (Sandbox Code Playgroud)
类似于base:
x$n = seq_along(x$a) - findInterval(x$a - 5, x$a)
Run Code Online (Sandbox Code Playgroud)