可以加速特殊 3 数组的计算吗?

g g*_*g g 6 performance r vectorization rcpp tensor

我想从暗淡 (K,N) 的两个矩阵 A 和暗淡 (d,N) 的 X 确定一个维度为 (K,d,d) 的 3 数组 R,其中 K 很小,d 适中,但 N 是大(有关典型值,请参阅下面的代码示例)。数组的公式是

R[k, i, j] = sum( A[k, ] * X[i, ] * X[j, ] )。

该数组必须计算多次,因此速度至关重要。因此,我想知道在 R 中计算这个最有效的方法是什么?

我目前的做法

我当前的方法在下面被列为“当前”和“天真的”方法,毫不奇怪,它的速度要慢得多。

library(microbenchmark)

K = 3
d = 20
N = 1e5

tt = microbenchmark(
  
  current = {
    for(krow in 1:K){
      tmp = X * matrix(A[krow,], d, N, byrow = TRUE)
      R[krow,,] = tmp %*% t(X)  
    }},
  
  naive = {
    for(krow in 1:K){
      for(irow in 1:d){
        for(jrow in 1:d){
          Ralt[krow, irow, jrow] = sum(A[krow,] * X[irow, ] * X[jrow,])
        }
      }
    }},

  check = "equal",
  
  setup = {
    A = matrix(runif(K*N), K, N)
    X = matrix(runif(d*N), d, N)
    R = array(0, dim = c(K, d, d))
    Ralt = array(0, dim = c(K, d, d))
  },
  times = 5
)

print(tt)
Run Code Online (Sandbox Code Playgroud)

问题

  • 您认为有什么方法可以改进这一点吗?例如,是否可以利用 R 在最后两个索引中是对称的这一事实?
  • 在 Rcpp 中实现此功能是否可以带来显着的改进(>30%)?

GKi*_*GKi 6

您可以转置t矩阵以启用列子集,这比行子集更快。这允许自动重复而不是创建新的矩阵。

tX <- t(X)
tA <- t(A)
for(krow in 1:K){
    . <- tX * tA[,krow]
    R[krow,,] <- t(.) %*% tX
}
Run Code Online (Sandbox Code Playgroud)

一个变体可能看起来像:

tX <- t(X)
tA <- t(A)
for(krow in 1:K) R[krow,,] <- crossprod(tX * tA[,krow], tX)
Run Code Online (Sandbox Code Playgroud)

在可能的情况下可以加快速度,crossprod例如Rfast::Crossprod(坦克到@jblood94的评论)。

Rcpp 变体可能看起来像(但目前比其他变体慢):

Rcpp::cppFunction(r"(void mmul(Rcpp::NumericMatrix A, Rcpp::NumericMatrix X, Rcpp::NumericVector R, int K, int d) {
  int KD = d*K;
  for(int i=0; i < d; ++i) {
    for(int j=0; j < d; ++j) {
      Rcpp::NumericVector tmp = X(_,i) * X(_,j);
      for(int k=0; k < K; ++k) {
        R[k + i*K + j*KD] = sum(A(_,k) * tmp);
      }
    }
  }
} )")

mmul(t(A), t(X), R, K, d)
Run Code Online (Sandbox Code Playgroud)

一个使用特征值:

Rcpp::sourceCpp(code=r"(
// [[Rcpp::depends(RcppEigen)]]
// [[Rcpp::plugins(openmp)]]

#include <omp.h>
#include <RcppEigen.h>

using namespace std;
using namespace Eigen;

// [[Rcpp::export]]
void mmulE(Eigen::MatrixXd A, Eigen::MatrixXd X, Rcpp::NumericVector R, int n_cores) {
  Eigen::setNbThreads(n_cores);
  for(int k=0; k < A.cols(); ++k) {
    Eigen::MatrixXd C = X.cwiseProduct(A.col(k).replicate(1, X.cols() ));
    Eigen::MatrixXd D = C.transpose() * X;
    for(int i=0; i<D.size(); ++i) {
      R[i*A.cols()+k] = D(i);
    }
  }
}
)")

mmulE(t(A), t(X), R, 1)
Run Code Online (Sandbox Code Playgroud)
library(microbenchmark)

K = 3
d = 20
N = 1e5

tt = microbenchmark(
  
  current = {
    for(krow in 1:K){
      tmp = X * matrix(A[krow,], d, N, byrow = TRUE)
      R[krow,,] = tmp %*% t(X)  
    }},
  
  GKi = {
      tX <- t(X)
      tA <- t(A)
      for(krow in 1:K){
          . <- tX * tA[,krow]
          R[krow,,] <- t(.) %*% tX
      }
  },

  crossp = {
      tX <- t(X)
      tA <- t(A)
      for(krow in 1:K) R[krow,,] <- crossprod(tX * tA[,krow], tX)
  },

  Rfast = {
    tX <- t(X)
    tA <- t(A)
    for(krow in 1:K) R[krow,,] <- Rfast::Crossprod(tX*tA[,krow], tX)
  },

  Rcpp = mmul(t(A), t(X), R, K, d),

  RcppE1C = mmulE(t(A), t(X), R, 1),
  RcppE2C = mmulE(t(A), t(X), R, 2),
  RcppE4C = mmulE(t(A), t(X), R, 4),

  check = "equal",
  
  setup = {
    A = matrix(runif(K*N), K, N)
    X = matrix(runif(d*N), d, N)
    R = array(0, dim = c(K, d, d))
  },
  times = 5
)

print(tt)
Run Code Online (Sandbox Code Playgroud)
Unit: milliseconds
    expr       min        lq      mean    median        uq       max neval
 current 106.44215 108.73900 161.66269 159.30184 216.37502 217.45546     5
     GKi  84.56926  87.98166 111.04126  90.18420  97.30869 195.16249     5
  crossp 112.02929 113.01796 113.67749 113.93593 114.49450 114.90976     5
   Rfast  39.12859  42.11124  45.42296  46.83398  49.46175  49.57924     5
    Rcpp 156.28284 156.38025 182.19358 157.05552 159.86193 281.38735     5
 RcppE1C  38.94770  40.49375  42.71140  40.69852  46.57995  46.83707     5
 RcppE2C  35.03088  35.67732  36.73970  36.52070  36.64065  39.82895     5
 RcppE4C  31.40532  33.94128  34.53725  34.40168  34.64187  38.29608     5
Run Code Online (Sandbox Code Playgroud)

也许还可以看看:
Crossprod 比 %*% 慢,为什么?
如何
在R中使crossprod更快快速大矩阵乘法