为什么这种天真的矩阵乘法比基数R更快?

Cli*_* AB 29 performance r matrix-multiplication rcpp

在R中,矩阵乘法非常优化,即实际上只是对BLAS/LAPACK的调用.但是,我很惊讶这个非常天真的C++代码用于矩阵向量乘法似乎可靠地快了30%.

 library(Rcpp)

 # Simple C++ code for matrix multiplication
 mm_code = 
 "NumericVector my_mm(NumericMatrix m, NumericVector v){
   int nRow = m.rows();
   int nCol = m.cols();
   NumericVector ans(nRow);
   double v_j;
   for(int j = 0; j < nCol; j++){
     v_j = v[j];
     for(int i = 0; i < nRow; i++){
       ans[i] += m(i,j) * v_j;
     }
   }
   return(ans);
 }
 "
 # Compiling
 my_mm = cppFunction(code = mm_code)

 # Simulating data to use
 nRow = 10^4
 nCol = 10^4

 m = matrix(rnorm(nRow * nCol), nrow = nRow)
 v = rnorm(nCol)

 system.time(my_ans <- my_mm(m, v))
#>    user  system elapsed 
#>   0.103   0.001   0.103 
 system.time(r_ans <- m %*% v)
#>   user  system elapsed 
#>  0.154   0.001   0.154 

 # Double checking answer is correct
 max(abs(my_ans - r_ans))
 #> [1] 0
Run Code Online (Sandbox Code Playgroud)

基础R是否%*%执行某些类型的数据检查我正在跳过?

编辑: 在了解了正在发生的事情后(感谢SO!),值得注意的是,这是R的最坏情况%*%,即矢量矩阵.例如,@ RalfStubner指出使用RcppArmadillo比使用简单的实现更快,但对于矩阵 - 矩阵乘法(当两个矩阵都是大和正方形时)几乎相同:

 arma_code <- 
   "arma::mat arma_mm(const arma::mat& m, const arma::mat& m2) {
 return m * m2;
 };"
 arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")

 nRow = 10^3 
 nCol = 10^3

 mat1 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)
 mat2 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)

 system.time(arma_mm(mat1, mat2))
#>   user  system elapsed 
#>   0.798   0.008   0.814 
 system.time(mat1 %*% mat2)
#>   user  system elapsed 
#>   0.807   0.005   0.822  
Run Code Online (Sandbox Code Playgroud)

所以R的电流(v3.5.0)%*%对于矩阵矩阵来说是接近最优的,但如果你可以跳过检查,可以显着加快矩阵向量的速度.

Jos*_*ien 27

快速浏览names.c(特别是这里)指向您do_matprod,调用的C函数%*%以及在文件中找到的函数array.c.(有趣的是,事实证明,这两个crossprodtcrossprod调度到相同的功能的孔).这是代码的链接do_matprod.

滚动浏览该函数,您可以看到它会处理您的天真实现所没有的一些事情,包括:

  1. 保留行名和列名,这是有意义的.
  2. 当通过调用操作的两个对象%*%属于已提供此类方法的类时,允许调度到备用S4方法.(这就是函数的这一部分发生的事情.)
  3. 处理实数和复数矩阵.
  4. 实现一系列规则,用于处理矩阵和矩阵的乘法,矢量和矩阵,矩阵和向量,以及向量和向量.(回想一下,在R中的交叉乘法中,LHS上的向量被视为行向量,而在RHS上,它被视为列向量;这是使得这样的代码.)

在函数末尾附近,它将调度到matprod或者或者cmatprod.有趣的是(至少对我而言),在真实矩阵的情况下,如果矩阵可能包含NaNInf值,则将matprod(在此处)调度到一个函数,该函数simple_matprod与您自己的函数一样简单明了.否则,它会调度到几个BLAS Fortran例程中的一个,如果可以保证统一的"良好行为"矩阵元素,可能会更快.

  • @CliffAB通过直接或间接通过RcppArmadillo和使用多线程BLAS使用适当的BLAS函数,您可能获得更多. (5认同)

Ral*_*ner 7

Josh的答案解释了为什么R的矩阵乘法不如这种天真的方法快.我很想知道使用RcppArmadillo可以获得多少收益.代码很简单:

arma_code <- 
  "arma::vec arma_mm(const arma::mat& m, const arma::vec& v) {
       return m * v;
   };"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")
Run Code Online (Sandbox Code Playgroud)

基准测试:

> microbenchmark::microbenchmark(my_mm(m,v), m %*% v, arma_mm(m,v), times = 10)
Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 71.23347 75.22364  90.13766  96.88279  98.07348  98.50182    10
       m %*% v 92.86398 95.58153 106.00601 111.61335 113.66167 116.09751    10
 arma_mm(m, v) 41.13348 41.42314  41.89311  41.81979  42.39311  42.78396    10
Run Code Online (Sandbox Code Playgroud)

所以RcppArmadillo为我们提供了更好的语法和更好的性能.

好奇心让我变得更好.这里是直接使用BLAS的解决方案:

blas_code = "
NumericVector blas_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  char trans = 'N';
  double one = 1.0, zero = 0.0;
  int ione = 1;
  F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
           &ione, &zero, ans.begin(), &ione);
  return ans;
}"
blas_mm <- cppFunction(code = blas_code, includes = "#include <R_ext/BLAS.h>")
Run Code Online (Sandbox Code Playgroud)

基准测试:

Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 72.61298 75.40050  89.75529  96.04413  96.59283  98.29938    10
       m %*% v 95.08793 98.53650 109.52715 111.93729 112.89662 128.69572    10
 arma_mm(m, v) 41.06718 41.70331  42.62366  42.47320  43.22625  45.19704    10
 blas_mm(m, v) 41.58618 42.14718  42.89853  42.68584  43.39182  44.46577    10
Run Code Online (Sandbox Code Playgroud)

犰狳和BLAS(在我的情况下是OpenBLAS)几乎是一样的.而BLAS代码也是R最终所做的.所以R的2/3是错误检查等.

  • 并且可能是OpenMP引导(假设您的OS /编译器支持它). (2认同)