Tao*_*Tan 14 performance r matrix blas matrix-multiplication
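I noticed that in R, evaluating a quadratic form from right to left is noticeably faster than evaluating it from left to right; which order is used depends only on where the parentheses are placed. Both orderings obviously perform the same amount of computation, so I would like to know why this happens. Does it have something to do with memory allocation?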
# A: 5000 * 5000
# B: 5000 * 2
A = matrix(runif(5000 * 5000), nrow = 5000)
B = matrix(rbinom(5000 * 2, size = 2, prob = 0.3), nrow = 5000)
library(microbenchmark)
microbenchmark((t(B) %*% A) %*% B, t(B) %*% (A %*% B), times = 100)
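As a rough check of the claim that both orderings do the same amount of arithmetic, the sketch below counts scalar multiplications for each parenthesization. `matmul_cost` is a hypothetical helper introduced only for illustration, and it assumes a naive product in which an (m x k) by (k x n) multiplication costs m * k * n multiplications; an optimized BLAS may organize the work differently.

```r
# Hypothetical helper (not part of any package): multiplications needed by a
# naive (m x k) %*% (k x n) product.
matmul_cost <- function(m, k, n) m * k * n

k <- 5000  # A is k x k
p <- 2     # B is k x p

# (t(B) %*% A) %*% B: (p x k)(k x k) first, then (p x k)(k x p)
left_first <- matmul_cost(p, k, k) + matmul_cost(p, k, p)

# t(B) %*% (A %*% B): (k x k)(k x p) first, then (p x k)(k x p)
right_first <- matmul_cost(k, k, p) + matmul_cost(p, k, p)

c(left_first = left_first, right_first = right_first)
```

Under that assumption both counts come to 50,020,000, so the difference in run time cannot come from the operation count alone.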
Here is the session info:
R version 4.2.0 (2022-04-22)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Big Sur 11.4
Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] Rcpp_1.0.9 microbenchmark_1.4.9
loaded via a namespace (and not attached):
[1] compiler_4.2.0 fastmap_1.1.0 cli_3.3.0 htmltools_0.5.3 tools_4.2.0
[6] RcppArmadillo_0.11.2.4.0 rstudioapi_0.13 yaml_2.3.5 rmarkdown_2.14 knitr_1.39
[11] xfun_0.31 digest_0.6.29 rlang_1.0.4 evaluate_0.15
Edit: a simplified version using plain matrix multiplication shows the same issue.
# A: 5000 * 5000
# B: 5000 * 2
A = matrix(runif(5000 * 5000), nrow = 5000)
B = matrix(rbinom(5000 * 2, size = 2, prob = 0.3), nrow = 5000)
microbenchmark((t(B) %*% A) %*% B, t(B) %*% (A %*% B), times = 100)
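This looks like a matter of implementation and of the order in which elements are accessed. See "Why does the order of the loops affect performance when iterating over a 2D array?".
When the same computation is done with the loops in a different order, the timings differ.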
Rcpp::cppFunction("NumericMatrix mm(const NumericMatrix& A, const NumericMatrix& B) {
int M = A.nrow();
int N = A.ncol();
int P = B.ncol();
NumericMatrix res(M, P);
for (int n=0; n<N; ++n) { //Loop n, p, m
  for (int p=0; p<P; ++p) {
    for (int m=0; m<M; ++m) {
      res[m+p*M] += A[m+M*n] * B[p*N+n];
    }
  }
}
return res;}")

Rcpp::cppFunction("NumericMatrix mm2(const NumericMatrix& A, const NumericMatrix& B) {
int M = A.nrow();
int N = A.ncol();
int P = B.ncol();
NumericMatrix res(M, P);
for (int m=0; m<M; ++m) { //Loop m, p, n
  for (int p=0; p<P; ++p) {
    for (int n=0; n<N; ++n) {
      res[m+p*M] += A[m+M*n] * B[p*N+n];
    }
  }
}
return res;}")

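The index arithmetic in `mm` and `mm2` (for example `A[m+M*n]`) relies on R storing matrices in column-major order, so that consecutive values of `m` are adjacent in memory while consecutive values of `n` are `M` elements apart. The following is a minimal sketch of that mapping on a small matrix; the variable `lin` and the chosen indices are illustrative only.

```r
# Small matrix so the layout is easy to inspect.
A <- matrix(1:6, nrow = 2)   # columns are (1, 2), (3, 4), (5, 6)
M <- nrow(A)

# Zero-based (m, n), as in the C++ code, maps to linear index m + M * n.
# R itself is one-based, hence the "+ 1L" terms below.
m <- 1L; n <- 2L             # zero-based row 1, column 2
lin <- m + M * n + 1L

A[lin] == A[m + 1L, n + 1L]  # both expressions address the same element
as.vector(A)                 # memory order: one column after another
```

Seen this way, the innermost loop of `mm` (over `m`) walks `A` and `res` contiguously, whereas the innermost loop of `mm2` (over `n`) jumps through `A` in steps of `M`, which is one plausible reason the two loop orders behave differently on tall versus wide inputs.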
k <- 5000L; m <- n <- 2L;
A <- matrix(rnorm(k * k), k, k);
B <- matrix(rnorm(k * n), k, n);
tB <- t(B);

met <- alist("(tB*A)B" = tB %*% A %*% B,
  "tB(A*B)" = tB %*% (A %*% B),
  "mm (tB*A)B" = mm(mm(tB, A), B),
  "mm tB(A*B)" = mm(tB, mm(A, B)),
  "mm2 (tB*A)B" = mm2(mm2(tB, A), B),
  "mm2 tB(A*B)" = mm2(tB, mm2(A, B)),
  "cp(B,A)B" = crossprod(B, A) %*% B,
  "cp(B,A*B)" = crossprod(B, A %*% B) )

bench::mark(exprs = met)
#  expression      min median itr/s…¹ mem_a…² gc/se…³ n_itr  n_gc total…⁴ result
#  <bch:expr>   <bch:> <bch:>   <dbl> <bch:b>   <dbl> <int> <dbl> <bch:t> <list>
#1 (tB*A)B      79.5ms 80.1ms    12.5 78.17KB       0     7     0   562ms <dbl[…]>
#2 tB(A*B)      33.8ms 34.4ms    28.4 78.17KB       0    15     0   528ms <dbl[…]>
#3 mm (tB*A)B   61.9ms 62.5ms    15.9  3.85MB       0     8     0   502ms <dbl[…]>
#4 mm tB(A*B)   19.9ms 20.7ms    48.1 83.16KB       0    25     0   520ms <dbl[…]>
#5 mm2 (tB*A)B  35.9ms 39.4ms    25.8 87.29KB       0    13     0   504ms <dbl[…]>
#6 mm2 tB(A*B)  47.8ms 48.1ms    20.6 83.16KB       0    11     0   535ms <dbl[…]>
#7 cp(B,A)B     44.1ms 44.5ms    22.4 80.42KB       0    12     0   536ms <dbl[…]>
#8 cp(B,A*B)    34.1ms 36.5ms    27.1 78.17KB       0    14     0   516ms <dbl[…]>

microbenchmark::microbenchmark(list = met)
#Unit: milliseconds
#        expr      min       lq     mean   median       uq      max neval
#     (tB*A)B 77.09484 77.86891 79.09483 78.44832 80.08971 87.05563   100
#     tB(A*B) 33.63306 34.22562 36.08482 35.14064 36.64080 51.39962   100
#  mm (tB*A)B 62.05235 64.14361 66.54568 65.16927 67.98617 75.96242   100
#  mm tB(A*B) 19.67066 20.28369 20.83781 20.53820 21.19940 23.64119   100
# mm2 (tB*A)B 35.31290 35.70006 36.62846 36.10282 37.41669 40.47473   100
# mm2 tB(A*B) 48.16574 49.70702 51.55844 50.26292 52.46479 67.44558   100
#    cp(B,A)B 43.18166 44.01366 45.28434 44.71301 46.41521 48.97891   100
#   cp(B,A*B) 33.62158 34.47070 35.84743 35.11853 36.55979 48.89021   100
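Since the benchmarks compare timings only, it can be worth confirming that the different parenthesizations and the `crossprod` variant agree numerically. The sketch below does this in base R on a smaller problem; the size `k = 200`, the seed, and the names `left`, `right`, and `cp` are arbitrary choices made for illustration.

```r
set.seed(1)                  # arbitrary seed, only for reproducibility
k <- 200L; n <- 2L           # smaller than the benchmark so the check is quick
A <- matrix(rnorm(k * k), k, k)
B <- matrix(rnorm(k * n), k, n)
tB <- t(B)

left  <- (tB %*% A) %*% B          # left-to-right
right <- tB %*% (A %*% B)          # right-to-left
cp    <- crossprod(B, A %*% B)     # crossprod(x, y) computes t(x) %*% y

# all.equal() compares up to a small floating-point tolerance.
agree <- isTRUE(all.equal(left, right)) && isTRUE(all.equal(left, cp))
agree
```

The results are compared with `all.equal()` rather than `identical()` because changing the order of floating-point operations can alter the last few bits of the result.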