我正在尝试为矩阵的每一列取累积和.这是我在R中的代码:
testMatrix = matrix(1:65536, ncol=256);
microbenchmark(apply(testMatrix, 2, cumsum), times=100L);
Unit: milliseconds
expr min lq mean median uq max neval
apply(testMatrix, 2, cumsum) 1.599051 1.766112 2.329932 2.15326 2.221538 93.84911 10000
Run Code Online (Sandbox Code Playgroud)
我用Rcpp进行比较:
cppFunction('NumericMatrix apply_cumsum_col(NumericMatrix m) {
for (int j = 0; j < m.ncol(); ++j) {
for (int i = 1; i < m.nrow(); ++i) {
m(i, j) += m(i - 1, j);
}
}
return m;
}');
microbenchmark(apply_cumsum_col(testMatrix), times=10000L);
Unit: microseconds
expr min lq mean median uq max neval
apply_cumsum_col(testMatrix) 205.833 257.719 309.9949 265.986 276.534 96398.93 10000
Run Code Online (Sandbox Code Playgroud)
所以C++代码的速度要快7.5倍.是否有可能比apply(testMatrix, 2, cumsum)纯R 更好?感觉就像没有任何理由我有一个数量级的开销.
用R代码很难击败C++.我能想到的最快的方法就是你愿意把你的矩阵分成一个列表.这样,R使用原始函数,并且不会在每次迭代时复制对象(apply本质上是一个漂亮的循环).您可以看到C++仍然胜出但list如果您真的只想使用R代码,那么该方法会有显着的加速.
fun1 <- function(){
apply(testMatrix, 2, cumsum)
}
testList <- split(testMatrix, col(testMatrix))
fun2 <- function(){
lapply(testList, cumsum)
}
microbenchmark(fun1(),
fun2(),
apply_cumsum_col(testMatrix),
times=100L)
Unit: microseconds
expr min lq mean median uq max neval
fun1() 3298.534 3411.9910 4376.4544 3477.608 3699.2485 9249.919 100
fun2() 558.800 596.0605 766.2377 630.841 659.3015 5153.100 100
apply_cumsum_col(testMatrix) 219.651 282.8570 576.9958 311.562 339.5680 4915.290 100
Run Code Online (Sandbox Code Playgroud)
编辑
请注意,此方法比fun1包含将矩阵拆分为列表的时间要慢.
使用字节编译的for循环比apply我的系统上的调用稍快.我期望它更快,因为它的工作量少于apply.正如所料,R循环仍然比您编写的简单C++函数慢.
colCumsum <- compiler::cmpfun(function(x) {
for (i in 1:ncol(x))
x[,i] <- cumsum(x[,i])
x
})
testMatrix <- matrix(1:65536, ncol=256)
m <- testMatrix
require(microbenchmark)
microbenchmark(colCumsum(m), apply_cumsum_col(m), apply(m, 2, cumsum), times=100L)
# Unit: microseconds
# expr min lq median uq max neval
# matrixCumsum(m) 1478.671 1540.5945 1586.1185 2199.9530 37377.114 100
# apply_cumsum_col(m) 178.214 192.4375 204.3905 234.8245 1616.030 100
# apply(m, 2, cumsum) 1879.850 1940.1615 1991.3125 2745.8975 4346.802 100
all.equal(colCumsum(m), apply(m, 2, cumsum))
# [1] TRUE
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1406 次 |
| 最近记录: |