Cli*_* AB 29 performance r matrix-multiplication rcpp
在R中,矩阵乘法非常优化,即实际上只是对BLAS/LAPACK的调用.但是,我很惊讶这个非常天真的C++代码用于矩阵向量乘法似乎可靠地快了30%.
library(Rcpp)
# Simple C++ code for matrix multiplication
mm_code =
"NumericVector my_mm(NumericMatrix m, NumericVector v){
int nRow = m.rows();
int nCol = m.cols();
NumericVector ans(nRow);
double v_j;
for(int j = 0; j < nCol; j++){
v_j = v[j];
for(int i = 0; i < nRow; i++){
ans[i] += m(i,j) * v_j;
}
}
return(ans);
}
"
# Compiling
my_mm = cppFunction(code = mm_code)
# Simulating data to use
nRow = 10^4
nCol = 10^4
m = matrix(rnorm(nRow * nCol), nrow = nRow)
v = rnorm(nCol)
system.time(my_ans <- my_mm(m, v))
#> user system elapsed
#> 0.103 0.001 0.103
system.time(r_ans <- m %*% v)
#> user system elapsed
#> 0.154 0.001 0.154
# Double checking answer is correct
max(abs(my_ans - r_ans))
#> [1] 0
Run Code Online (Sandbox Code Playgroud)
基础R是否%*%执行某些类型的数据检查我正在跳过?
编辑:
在了解了正在发生的事情后(感谢SO!),值得注意的是,这是R的最坏情况%*%,即矢量矩阵.例如,@ RalfStubner指出使用RcppArmadillo比使用简单的实现更快,但对于矩阵 - 矩阵乘法(当两个矩阵都是大和正方形时)几乎相同:
arma_code <-
"arma::mat arma_mm(const arma::mat& m, const arma::mat& m2) {
return m * m2;
};"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")
nRow = 10^3
nCol = 10^3
mat1 = matrix(rnorm(nRow * nCol),
nrow = nRow)
mat2 = matrix(rnorm(nRow * nCol),
nrow = nRow)
system.time(arma_mm(mat1, mat2))
#> user system elapsed
#> 0.798 0.008 0.814
system.time(mat1 %*% mat2)
#> user system elapsed
#> 0.807 0.005 0.822
Run Code Online (Sandbox Code Playgroud)
所以R的电流(v3.5.0)%*%对于矩阵矩阵来说是接近最优的,但如果你可以跳过检查,可以显着加快矩阵向量的速度.
Jos*_*ien 27
快速浏览names.c(特别是这里)指向您do_matprod,调用的C函数%*%以及在文件中找到的函数array.c.(有趣的是,事实证明,这两个crossprod和tcrossprod调度到相同的功能的孔).这是代码的链接do_matprod.
滚动浏览该函数,您可以看到它会处理您的天真实现所没有的一些事情,包括:
%*%属于已提供此类方法的类时,允许调度到备用S4方法.(这就是函数的这一部分发生的事情.)在函数末尾附近,它将调度到matprod或者或者cmatprod.有趣的是(至少对我而言),在真实矩阵的情况下,如果矩阵可能包含NaN或Inf值,则将matprod(在此处)调度到一个函数,该函数simple_matprod与您自己的函数一样简单明了.否则,它会调度到几个BLAS Fortran例程中的一个,如果可以保证统一的"良好行为"矩阵元素,可能会更快.
Josh的答案解释了为什么R的矩阵乘法不如这种天真的方法快.我很想知道使用RcppArmadillo可以获得多少收益.代码很简单:
arma_code <-
"arma::vec arma_mm(const arma::mat& m, const arma::vec& v) {
return m * v;
};"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")
Run Code Online (Sandbox Code Playgroud)
基准测试:
> microbenchmark::microbenchmark(my_mm(m,v), m %*% v, arma_mm(m,v), times = 10)
Unit: milliseconds
expr min lq mean median uq max neval
my_mm(m, v) 71.23347 75.22364 90.13766 96.88279 98.07348 98.50182 10
m %*% v 92.86398 95.58153 106.00601 111.61335 113.66167 116.09751 10
arma_mm(m, v) 41.13348 41.42314 41.89311 41.81979 42.39311 42.78396 10
Run Code Online (Sandbox Code Playgroud)
所以RcppArmadillo为我们提供了更好的语法和更好的性能.
好奇心让我变得更好.这里是直接使用BLAS的解决方案:
blas_code = "
NumericVector blas_mm(NumericMatrix m, NumericVector v){
int nRow = m.rows();
int nCol = m.cols();
NumericVector ans(nRow);
char trans = 'N';
double one = 1.0, zero = 0.0;
int ione = 1;
F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
&ione, &zero, ans.begin(), &ione);
return ans;
}"
blas_mm <- cppFunction(code = blas_code, includes = "#include <R_ext/BLAS.h>")
Run Code Online (Sandbox Code Playgroud)
基准测试:
Unit: milliseconds
expr min lq mean median uq max neval
my_mm(m, v) 72.61298 75.40050 89.75529 96.04413 96.59283 98.29938 10
m %*% v 95.08793 98.53650 109.52715 111.93729 112.89662 128.69572 10
arma_mm(m, v) 41.06718 41.70331 42.62366 42.47320 43.22625 45.19704 10
blas_mm(m, v) 41.58618 42.14718 42.89853 42.68584 43.39182 44.46577 10
Run Code Online (Sandbox Code Playgroud)
犰狳和BLAS(在我的情况下是OpenBLAS)几乎是一样的.而BLAS代码也是R最终所做的.所以R的2/3是错误检查等.
| 归档时间: |
|
| 查看次数: |
1452 次 |
| 最近记录: |