假设我有一个数据框如下:
> foo = data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))
> foo
x id
1 1 1
2 2 1
3 3 2
4 4 2
5 5 2
6 6 3
7 7 3
8 8 3
9 9 3
Run Code Online (Sandbox Code Playgroud)
我想要一个非常有效的h(a,b)实现,它计算xi的总和(a - xi)*(b - xj),xj属于同一个id类.例如,我当前的实现是
h(a, b, foo){
a.diff = a - foo$x
b.diff = b - foo$x
prod = a.diff%*%t(b.diff)
id.indicator = as.matrix(ifelse(dist(foo$id, diag = T, upper = T),0,1)) + diag(nrow(foo))
return(sum(prod*id.indicator))
}
Run Code Online (Sandbox Code Playgroud)
例如,使用(a,b)=(0,1),这是函数中每个步骤的输出
> a.diff
[1] -1 -2 -3 -4 -5 -6 -7 -8 -9
> b.diff
[1] 0 -1 -2 -3 -4 -5 -6 -7 -8
> prod
[,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9]
[1,] 0 1 2 3 4 5 6 7 8
[2,] 0 2 4 6 8 10 12 14 16
[3,] 0 3 6 9 12 15 18 21 24
[4,] 0 4 8 12 16 20 24 28 32
[5,] 0 5 10 15 20 25 30 35 40
[6,] 0 6 12 18 24 30 36 42 48
[7,] 0 7 14 21 28 35 42 49 56
[8,] 0 8 16 24 32 40 48 56 64
[9,] 0 9 18 27 36 45 54 63 72
> id.indicator
1 2 3 4 5 6 7 8 9
1 1 1 0 0 0 0 0 0 0
2 1 1 0 0 0 0 0 0 0
3 0 0 1 1 1 0 0 0 0
4 0 0 1 1 1 0 0 0 0
5 0 0 1 1 1 0 0 0 0
6 0 0 0 0 0 1 1 1 1
7 0 0 0 0 0 1 1 1 1
8 0 0 0 0 0 1 1 1 1
9 0 0 0 0 0 1 1 1 1
Run Code Online (Sandbox Code Playgroud)
实际上,最多可以有1000个id簇,每个簇至少有40个,这使得这个方法效率太低,因为id.indicator中的稀疏条目和在块外对角线上的额外计算"赢了"使用.
我玩了一下.首先,你的实施:
foo = data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))
h <- function(a, b, foo){
a.diff = a - foo$x
b.diff = b - foo$x
prod = a.diff%*%t(b.diff)
id.indicator = as.matrix(ifelse(dist(foo$id, diag = T, upper = T),0,1)) +
diag(nrow(foo))
return(sum(prod*id.indicator))
}
h(a = 1, b = 0, foo = foo)
#[1] 891
Run Code Online (Sandbox Code Playgroud)
接下来,我尝试使用适当的稀疏矩阵实现(通过Matrix包)和索引矩阵的函数的变体.我也tcrossprod经常使用它比使用它快一点a %*% t(b).
library("Matrix")
h2 <- function(a, b, foo) {
a.diff <- a - foo$x
b.diff <- b - foo$x
prod <- tcrossprod(a.diff, b.diff) # the same as a.diff%*%t(b.diff)
id.indicator <- do.call(bdiag, lapply(table(foo$id), function(n) matrix(1,n,n)))
return(sum(prod*id.indicator))
}
h2(a = 1, b = 0, foo = foo)
#[1] 891
Run Code Online (Sandbox Code Playgroud)
请注意,此函数依赖于foo$id排序.
最后,我尝试避免创建完整的n×n矩阵.
h3 <- function(a, b, foo) {
a.diff <- a - foo$x
b.diff <- b - foo$x
ids <- unique(foo$id)
res <- 0
for (i in seq_along(ids)) {
indx <- which(foo$id == ids[i])
res <- res + sum(tcrossprod(a.diff[indx], b.diff[indx]))
}
return(res)
}
h3(a = 1, b = 0, foo = foo)
#[1] 891
Run Code Online (Sandbox Code Playgroud)
对您的示例进行基准测试:
library("microbenchmark")
microbenchmark(h(a = 1, b = 0, foo = foo),
h2(a = 1, b = 0, foo = foo),
h3(a = 1, b = 0, foo = foo))
# Unit: microseconds
# expr min lq mean median uq max neval
# h(a = 1, b = 0, foo = foo) 248.569 261.9530 493.2326 279.3530 298.2825 21267.890 100
# h2(a = 1, b = 0, foo = foo) 4793.546 4893.3550 5244.7925 5051.2915 5386.2855 8375.607 100
# h3(a = 1, b = 0, foo = foo) 213.386 227.1535 243.1576 234.6105 248.3775 334.612 100
Run Code Online (Sandbox Code Playgroud)
现在,在这个例子中,它h3是最快的,而且h2非常慢.但我想对于更大的例子来说两者都会更快.也许,h3仍然会赢得更大的例子.虽然有足够的空间进行更多优化,但h3应该更快,更高效.所以,我认为你应该选择一种h3不会产生不必要的大型矩阵的变体.
tapply允许您跨向量组应用函数,并在可能的情况下将结果简化为矩阵或向量。用于tcrossprod将每个组的所有组合相乘,并且在一些适当大的数据上它表现良好:
# setup
set.seed(47)
foo = data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))
foo2 <- data.frame(id = sample(1000, 40000, TRUE), x = rnorm(40000))
h_OP <- function(a, b, foo){
a.diff = a - foo$x
b.diff = b - foo$x
prod = a.diff %*% t(b.diff)
id.indicator = as.matrix(ifelse(dist(foo$id, diag = T, upper = T),0,1)) + diag(nrow(foo))
return(sum(prod * id.indicator))
}
h3_AEBilgrau <- function(a, b, foo) {
a.diff <- a - foo$x
b.diff <- b - foo$x
ids <- unique(foo$id)
res <- 0
for (i in seq_along(ids)) {
indx <- which(foo$id == ids[i])
res <- res + sum(tcrossprod(a.diff[indx], b.diff[indx]))
}
return(res)
}
h_d.b <- function(a, b, foo){
sum(sapply(split(foo, foo$id), function(d) sum(outer(a-d$x, b-d$x))))
}
h_alistaire <- function(a, b, foo){
sum(tapply(foo$x, foo$id, function(x){sum(tcrossprod(a - x, b - x))}))
}
Run Code Online (Sandbox Code Playgroud)
所有返回相同的东西,并且在小数据上没有那么不同:
h_OP(0, 1, foo)
#> [1] 891
h3_AEBilgrau(0, 1, foo)
#> [1] 891
h_d.b(0, 1, foo)
#> [1] 891
h_alistaire(0, 1, foo)
#> [1] 891
# small data test
microbenchmark::microbenchmark(
h_OP(0, 1, foo),
h3_AEBilgrau(0, 1, foo),
h_d.b(0, 1, foo),
h_alistaire(0, 1, foo)
)
#> Unit: microseconds
#> expr min lq mean median uq max neval cld
#> h_OP(0, 1, foo) 143.749 157.8895 189.5092 189.7235 214.3115 262.258 100 b
#> h3_AEBilgrau(0, 1, foo) 80.970 93.8195 112.0045 106.9285 125.9835 225.855 100 a
#> h_d.b(0, 1, foo) 355.084 381.0385 467.3812 437.5135 516.8630 2056.972 100 c
#> h_alistaire(0, 1, foo) 148.735 165.1360 194.7361 189.9140 216.7810 287.990 100 b
Run Code Online (Sandbox Code Playgroud)
然而,在更大的数据上,差异变得更加明显。最初的版本可能会让我的笔记本电脑崩溃,但以下是最快的两台笔记本电脑的基准测试:
# on 1k groups, 40k rows
microbenchmark::microbenchmark(
h3_AEBilgrau(0, 1, foo2),
h_alistaire(0, 1, foo2)
)
#> Unit: milliseconds
#> expr min lq mean median uq max neval cld
#> h3_AEBilgrau(0, 1, foo2) 336.98199 403.04104 412.06778 410.52391 423.33008 443.8286 100 b
#> h_alistaire(0, 1, foo2) 14.00472 16.25852 18.07865 17.22296 18.09425 96.9157 100 a
Run Code Online (Sandbox Code Playgroud)
另一种可能性是使用 data.frame 按组汇总,然后对适当的列求和。在基础 R 中,您可以使用 来执行此操作aggregate,但 dplyr 和 data.table 很受欢迎,可以通过更复杂的聚合使这种方法变得更简单。
aggregate比 慢tapply。dplyr 比 更快aggregate,但仍然更慢。data.table 是为速度而设计的,几乎与tapply.
library(dplyr)
library(data.table)
h_aggregate <- function(a, b, foo){sum(aggregate(x ~ id, foo, function(x){sum(tcrossprod(a - x, b - x))})$x)}
tidy_h <- function(a, b, foo){foo %>% group_by(id) %>% summarise(x = sum(tcrossprod(a - x, b - x))) %>% select(x) %>% sum()}
h_dt <- function(a, b, foo){setDT(foo)[, .(x = sum(tcrossprod(a - x, b - x))), by = id][, sum(x)]}
microbenchmark::microbenchmark(
h_alistaire(1, 0, foo2),
h_aggregate(1, 0, foo2),
tidy_h(1, 0, foo2),
h_dt(1, 0, foo2)
)
#> Unit: milliseconds
#> expr min lq mean median uq max neval cld
#> h_alistaire(1, 0, foo2) 13.30518 15.52003 18.64940 16.48818 18.13686 62.35675 100 a
#> h_aggregate(1, 0, foo2) 93.08401 96.61465 107.14391 99.16724 107.51852 143.16473 100 c
#> tidy_h(1, 0, foo2) 39.47244 42.22901 45.05550 43.94508 45.90303 90.91765 100 b
#> h_dt(1, 0, foo2) 13.31817 15.09805 17.27085 16.46967 17.51346 56.34200 100 a
Run Code Online (Sandbox Code Playgroud)