Faster evaluation of matrix multiplication from right to left

Question by Tao*_*Tan · score 14 · tags: performance, r, matrix, blas, matrix-multiplication

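I noticed that in R, evaluating a quadratic form from right to left is noticeably faster than evaluating it from left to right. Which one you get depends on where the parentheses are placed. Both orders obviously perform the same number of arithmetic operations, so I would like to know why this happens. Does it have something to do with memory allocation?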

library(microbenchmark)

# A: 5000 * 5000
# B: 5000 * 2
A = matrix(runif(5000 * 5000), nrow = 5000)
B = matrix(rbinom(5000 * 2, size = 2, prob = 0.3), nrow = 5000)

microbenchmark((t(B) %*% A) %*% B, t(B) %*% (A %*% B), times = 100)
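A quick way to confirm that the two parenthesizations really do the same work is to count multiply-adds. This is a minimal sketch assuming the schoolbook algorithm, where a p × q by q × r product costs p·q·r multiply-adds; `matmul_flops`, `left_to_right_flops` and `right_to_left_flops` are hypothetical helpers for illustration only.

```cpp
#include <cassert>
#include <cstdint>

// Multiply-adds for a schoolbook (p x q) * (q x r) product: one per (i, l, j) triple.
std::int64_t matmul_flops(std::int64_t p, std::int64_t q, std::int64_t r) {
    assert(p > 0 && q > 0 && r > 0);
    return p * q * r;
}

// Cost of (t(B) %*% A) %*% B for A: n x n and B: n x k.
std::int64_t left_to_right_flops(std::int64_t n, std::int64_t k) {
    return matmul_flops(k, n, n)    // t(B) %*% A     gives a k x n matrix
         + matmul_flops(k, n, k);   // (...) %*% B    gives a k x k matrix
}

// Cost of t(B) %*% (A %*% B) for A: n x n and B: n x k.
std::int64_t right_to_left_flops(std::int64_t n, std::int64_t k) {
    return matmul_flops(n, n, k)    // A %*% B        gives an n x k matrix
         + matmul_flops(k, n, k);   // t(B) %*% (...) gives a k x k matrix
}
```

For A of size 5000 × 5000 and B of size 5000 × 2, both functions return 2·5000·5000 + 2·5000·2 = 50,020,000, so the timing difference is not explained by the amount of arithmetic.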


Here is the session info:

R version 4.2.0 (2022-04-22)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Big Sur 11.4

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] Rcpp_1.0.9           microbenchmark_1.4.9

loaded via a namespace (and not attached):
 [1] compiler_4.2.0           fastmap_1.1.0            cli_3.3.0                htmltools_0.5.3          tools_4.2.0             
 [6] RcppArmadillo_0.11.2.4.0 rstudioapi_0.13          yaml_2.3.5               rmarkdown_2.14           knitr_1.39              
[11] xfun_0.31                digest_0.6.29            rlang_1.0.4              evaluate_0.15           

Edit: The simplified version of the matrix multiplication (the code above) shows the same behavior.


Answer by GKi · score 4

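It looks like this comes down to the implementation and the order in which elements are accessed. See "Why does the order of the loops affect performance when iterating over a 2D array?"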
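R stores matrices in column-major order, so element (i, j) of an nrow × ncol matrix sits at offset i + j·nrow. A minimal sketch of what the linked question is about, using a plain `std::vector<double>` as a stand-in for an R matrix (`ColMajor`, `sum_contiguous` and `sum_strided` are hypothetical names): both functions visit the same elements, but only the first walks memory with stride 1 in its inner loop.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Column-major storage, as R uses: element (i, j) lives at data[i + j * nrow].
struct ColMajor {
    std::size_t nrow, ncol;
    std::vector<double> data;

    double at(std::size_t i, std::size_t j) const {
        assert(i < nrow && j < ncol);
        return data[i + j * nrow];
    }
};

// Inner loop over rows: consecutive reads are adjacent in memory (stride 1).
double sum_contiguous(const ColMajor& x) {
    double s = 0.0;
    for (std::size_t j = 0; j < x.ncol; ++j)
        for (std::size_t i = 0; i < x.nrow; ++i)
            s += x.at(i, j);
    return s;
}

// Inner loop over columns: consecutive reads are nrow doubles apart.
double sum_strided(const ColMajor& x) {
    double s = 0.0;
    for (std::size_t i = 0; i < x.nrow; ++i)
        for (std::size_t j = 0; j < x.ncol; ++j)
            s += x.at(i, j);
    return s;
}
```

For a 5000 × 5000 matrix, consecutive reads in `sum_strided` are 40,000 bytes apart, so each one lands on a different cache line, while `sum_contiguous` uses every byte of each line it loads.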


Computing the same product but changing the order of the loops gives different timings:

Rcpp::cppFunction("NumericMatrix mm(const NumericMatrix& A, const NumericMatrix& B) {
  // Column-major indexing: A(m, n) is A[m + M*n], B(n, p) is B[p*N + n]
  int M = A.nrow();
  int N = A.ncol();
  int P = B.ncol();
  NumericMatrix res(M, P);  // starts filled with zeros
  for (int n = 0; n < N; ++n) {  // Loop n, p, m
    for (int p = 0; p < P; ++p) {
      for (int m = 0; m < M; ++m) {
        res[m + p*M] += A[m + M*n] * B[p*N + n];
      }
    }
  }
  return res;
}")

Rcpp::cppFunction("NumericMatrix mm2(const NumericMatrix& A, const NumericMatrix& B) {
  // Same product as mm, only the loop nesting differs
  int M = A.nrow();
  int N = A.ncol();
  int P = B.ncol();
  NumericMatrix res(M, P);
  for (int m = 0; m < M; ++m) {  // Loop m, p, n
    for (int p = 0; p < P; ++p) {
      for (int n = 0; n < N; ++n) {
        res[m + p*M] += A[m + M*n] * B[p*N + n];
      }
    }
  }
  return res;
}")
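To make the access pattern explicit, here is a plain C++ port of `mm` and `mm2` without Rcpp. It is a sketch under stated assumptions: a column-major `std::vector<double>` stands in for `NumericMatrix`, and `Mat`, `mm_npm` and `mm_mpn` are hypothetical names. Both functions compute the same product and differ only in which index the innermost loop runs over.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Column-major matrix, as in R: element (i, j) is at v[i + j * nrow].
struct Mat {
    std::size_t nrow, ncol;
    std::vector<double> v;
};

// Port of mm: loop order n, p, m. The inner loop is stride 1 in A and in res.
Mat mm_npm(const Mat& A, const Mat& B) {
    assert(A.ncol == B.nrow);
    const std::size_t M = A.nrow, N = A.ncol, P = B.ncol;
    Mat res{M, P, std::vector<double>(M * P, 0.0)};
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t p = 0; p < P; ++p)
            for (std::size_t m = 0; m < M; ++m)
                res.v[m + p * M] += A.v[m + M * n] * B.v[n + p * N];
    return res;
}

// Port of mm2: loop order m, p, n. The inner loop is stride 1 in B, stride M in A.
Mat mm_mpn(const Mat& A, const Mat& B) {
    assert(A.ncol == B.nrow);
    const std::size_t M = A.nrow, N = A.ncol, P = B.ncol;
    Mat res{M, P, std::vector<double>(M * P, 0.0)};
    for (std::size_t m = 0; m < M; ++m)
        for (std::size_t p = 0; p < P; ++p)
            for (std::size_t n = 0; n < N; ++n)
                res.v[m + p * M] += A.v[m + M * n] * B.v[n + p * N];
    return res;
}
```

With the shapes from the question, `mm_npm(A, B)` (A 5000 × 5000, B 5000 × 2) streams each column of the big matrix with stride 1, while `mm_mpn(A, B)` reads it with stride 5000. For `tB %*% A` the roles flip: the big matrix is now the second operand, `mm_npm` walks it with stride 5000 in its middle loop around an inner loop of only 2 iterations, whereas `mm_mpn` reads it contiguously. As far as I can tell, this is consistent with the timings in this answer, where `mm tB(A*B)` and `mm2 (tB*A)B` are the faster variants of each function.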
k <- 5000L; m <- n <- 2L
A <- matrix(rnorm(k * k), k, k)
B <- matrix(rnorm(k * n), k, n)
tB <- t(B)

met <- alist("(tB*A)B"     = tB %*% A %*% B,
             "tB(A*B)"     = tB %*% (A %*% B),
             "mm (tB*A)B"  = mm(mm(tB, A), B),
             "mm tB(A*B)"  = mm(tB, mm(A, B)),
             "mm2 (tB*A)B" = mm2(mm2(tB, A), B),
             "mm2 tB(A*B)" = mm2(tB, mm2(A, B)),
             "cp(B,A)B"    = crossprod(B, A) %*% B,
             "cp(B,A*B)"   = crossprod(B, A %*% B) )
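The `cp` variants use `crossprod(B, A)`, which computes t(B) %*% A without building t(B) first. The sketch below shows why that layout is convenient, assuming column-major storage: each entry of t(X) %*% Y is a dot product of a column of X with a column of Y, and both columns are contiguous in memory. This mirrors the loop structure of the reference BLAS `dgemm` for a transposed first operand, but it is an illustration only; an optimized BLAS may block and vectorize differently, and `crossprod_cm` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Column-major matrix: element (i, j) is at v[i + j * nrow].
struct Mat {
    std::size_t nrow, ncol;
    std::vector<double> v;
};

// Sketch of crossprod(X, Y) = t(X) %*% Y without forming t(X).
// Entry (i, j) is the dot product of column i of X with column j of Y.
Mat crossprod_cm(const Mat& X, const Mat& Y) {
    assert(X.nrow == Y.nrow);
    const std::size_t K = X.nrow;
    Mat res{X.ncol, Y.ncol, std::vector<double>(X.ncol * Y.ncol, 0.0)};
    for (std::size_t j = 0; j < Y.ncol; ++j) {
        const double* y = Y.v.data() + j * K;      // column j of Y, contiguous
        for (std::size_t i = 0; i < X.ncol; ++i) {
            const double* x = X.v.data() + i * K;  // column i of X, contiguous
            double dot = 0.0;
            for (std::size_t k = 0; k < K; ++k) dot += x[k] * y[k];
            res.v[i + j * X.ncol] = dot;
        }
    }
    return res;
}
```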
bench::mark(exprs = met)
#  expression     min median itr/sec mem_alloc gc/sec n_itr  n_gc total_time result
#  <bch:expr>  <bch:> <bch:>   <dbl>   <bch:b>  <dbl> <int> <dbl>    <bch:t> <list>
#1 (tB*A)B     79.5ms 80.1ms    12.5   78.17KB      0     7     0      562ms <dbl[…]>
#2 tB(A*B)     33.8ms 34.4ms    28.4   78.17KB      0    15     0      528ms <dbl[…]>
#3 mm (tB*A)B  61.9ms 62.5ms    15.9    3.85MB      0     8     0      502ms <dbl[…]>
#4 mm tB(A*B)  19.9ms 20.7ms    48.1   83.16KB      0    25     0      520ms <dbl[…]>
#5 mm2 (tB*A)B 35.9ms 39.4ms    25.8   87.29KB      0    13     0      504ms <dbl[…]>
#6 mm2 tB(A*B) 47.8ms 48.1ms    20.6   83.16KB      0    11     0      535ms <dbl[…]>
#7 cp(B,A)B    44.1ms 44.5ms    22.4   80.42KB      0    12     0      536ms <dbl[…]>
#8 cp(B,A*B)   34.1ms 36.5ms    27.1   78.17KB      0    14     0      516ms <dbl[…]>

microbenchmark::microbenchmark(list = met)
#Unit: milliseconds
#        expr      min       lq     mean   median       uq      max neval
#     (tB*A)B 77.09484 77.86891 79.09483 78.44832 80.08971 87.05563   100
#     tB(A*B) 33.63306 34.22562 36.08482 35.14064 36.64080 51.39962   100
#  mm (tB*A)B 62.05235 64.14361 66.54568 65.16927 67.98617 75.96242   100
#  mm tB(A*B) 19.67066 20.28369 20.83781 20.53820 21.19940 23.64119   100
# mm2 (tB*A)B 35.31290 35.70006 36.62846 36.10282 37.41669 40.47473   100
# mm2 tB(A*B) 48.16574 49.70702 51.55844 50.26292 52.46479 67.44558   100
#    cp(B,A)B 43.18166 44.01366 45.28434 44.71301 46.41521 48.97891   100
#   cp(B,A*B) 33.62158 34.47070 35.84743 35.11853 36.55979 48.89021   100
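The fastest row, `mm tB(A*B)`, is the one whose inner loop streams the large matrix down its columns. Pushing that idea further, the whole quadratic form t(B) %*% A %*% B can be computed in a single pass over A in storage order, without an n × k or k × n intermediate. This is a hypothetical sketch (`quad_form_cm` is not part of R, bench or microbenchmark), and whether it actually beats the BLAS-backed expressions on a given machine would have to be measured.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Column-major matrix: element (i, j) is at v[i + j * nrow].
struct Mat {
    std::size_t nrow, ncol;
    std::vector<double> v;
};

// t(B) %*% A %*% B for A: n x n and B: n x k, reading A once, column by column.
Mat quad_form_cm(const Mat& A, const Mat& B) {
    assert(A.nrow == A.ncol && A.nrow == B.nrow);
    const std::size_t n = A.nrow, k = B.ncol;
    Mat res{k, k, std::vector<double>(k * k, 0.0)};
    std::vector<double> w(k);  // w[i] = (t(B) %*% A)[i, col]
    for (std::size_t col = 0; col < n; ++col) {
        const double* a = A.v.data() + col * n;    // column `col` of A, contiguous
        for (std::size_t i = 0; i < k; ++i) {
            const double* b = B.v.data() + i * n;  // column i of B, contiguous
            double dot = 0.0;
            for (std::size_t r = 0; r < n; ++r) dot += b[r] * a[r];
            w[i] = dot;
        }
        // Rank-one update: res[i, j] += w[i] * B[col, j].
        for (std::size_t j = 0; j < k; ++j)
            for (std::size_t i = 0; i < k; ++i)
                res.v[i + j * k] += w[i] * B.v[col + j * n];
    }
    return res;
}
```

It performs the same k·n² + n·k² multiply-adds as either parenthesization; only the order in which memory is touched changes.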