matlab/octave - 广义矩阵乘法

gab*_*ous 10 matlab matrix octave matrix-multiplication

我想做一个函数来推广矩阵乘法.基本上,它应该能够进行标准矩阵乘法,但它应该允许通过任何其他函数更改两个二元运算符product/sum.

目标是在CPU和内存方面尽可能高效.当然,它总是比A*B效率低,但操作员的灵活性才是最重要的.

以下是我在阅读各种 有趣 线程后可以提出的一些命令:

A = randi(10, 2, 3);
B = randi(10, 3, 4);

% 1st method
C = sum(bsxfun(@mtimes, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)
% Alternative: C = bsxfun(@(a,b) mtimes(a',b), A', permute(B, [1 3 2]))

% 2nd method
C = sum(bsxfun(@(a,b) a*b, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)

% 3rd method (Octave-only)
C = sum(permute(A, [1 3 2]) .* permute(B, [3 2 1]), 3)

% 4th method (Octave-only): multiply nxm A with nx1xd B to create a nxmxd array
C = bsxfun(@(a, b) sum(times(a,b)), A', permute(B, [1 3 2]));
C = C2 = squeeze(C(1,:,:)); % sum and turn into mxd
Run Code Online (Sandbox Code Playgroud)

方法1-3的问题在于它们将在使用sum()折叠它们之前生成n个矩阵.4更好,因为它在bsxfun中执行sum(),但是bsxfun仍然生成n个矩阵(除了它们大部分是空的,只包含一个非零值向量的总和,其余的用0填充以匹配尺寸要求).

我想要的是像第四种方法,但没有无用的0来节省内存.

任何的想法?

Amr*_*mro 4

这是您发布的解决方案的稍微完善的版本,有一些小的改进。

我们检查行数是否多于列数,或者反之亦然,然后通过选择将行与矩阵相乘或将矩阵与列相乘(从而进行最少的循环迭代)来相应地进行乘法。

A*B

注意:即使行数少于列数,这也可能并不总是最好的策略(按行而不是按列);事实上,MATLAB 数组在内存中按列优先顺序存储,因此按列切片的效率更高,因为元素是连续存储的。而访问行涉及按步幅遍历元素(这对缓存不友好——想想空间局部性)。

除此之外,代码应该处理双精度/单精度、实数/复数、完整/稀疏(以及不可能组合的错误)。它还尊重空矩阵和零维度。

function C = my_mtimes(A, B, outFcn, inFcn)
    % default arguments
    if nargin < 4, inFcn = @times; end
    if nargin < 3, outFcn = @sum; end

    % check valid input
    assert(ismatrix(A) && ismatrix(B), 'Inputs must be 2D matrices.');
    assert(isequal(size(A,2),size(B,1)),'Inner matrix dimensions must agree.');
    assert(isa(inFcn,'function_handle') && isa(outFcn,'function_handle'), ...
        'Expecting function handles.')

    % preallocate output matrix
    M = size(A,1);
    N = size(B,2);
    if issparse(A)
        args = {'like',A};
    elseif issparse(B)
        args = {'like',B};
    else
        args = {superiorfloat(A,B)};
    end
    C = zeros(M,N, args{:});

    % compute matrix multiplication
    % http://en.wikipedia.org/wiki/Matrix_multiplication#Inner_product
    if M < N
        % concatenation of products of row vectors with matrices
        % A*B = [a_1*B ; a_2*B ; ... ; a_m*B]
        for m=1:M
            %C(m,:) = A(m,:) * B;
            %C(m,:) = sum(bsxfun(@times, A(m,:)', B), 1);
            C(m,:) = outFcn(bsxfun(inFcn, A(m,:)', B), 1);
        end
    else
        % concatenation of products of matrices with column vectors
        % A*B = [A*b_1 , A*b_2 , ... , A*b_n]
        for n=1:N
            %C(:,n) = A * B(:,n);
            %C(:,n) = sum(bsxfun(@times, A, B(:,n)'), 2);
            C(:,n) = outFcn(bsxfun(inFcn, A, B(:,n)'), 2);
        end
    end
end
Run Code Online (Sandbox Code Playgroud)

比较

毫无疑问,该函数整体上速度较慢,但​​对于较大的尺寸,它比内置矩阵乘法差几个数量级:

        (tic/toc times in seconds)
      (tested in R2014a on Windows 8)

    size      mtimes       my_mtimes 
    ____    __________     _________
     400     0.0026398       0.20282
     600      0.012039       0.68471
     800      0.014571        1.6922
    1000      0.026645        3.5107
    2000       0.20204         28.76
    4000        1.5578        221.51
Run Code Online (Sandbox Code Playgroud)

mtimes_vs_mymtimes

这是测试代码:

sz = [10:10:100 200:200:1000 2000 4000];
t = zeros(numel(sz),2);
for i=1:numel(sz)
    n = sz(i); disp(n)
    A = rand(n,n);
    B = rand(n,n);

    tic
    C = A*B;
    t(i,1) = toc;
    tic
    D = my_mtimes(A,B);
    t(i,2) = toc;

    assert(norm(C-D) < 1e-6)
    clear A B C D
end

semilogy(sz, t*1000, '.-')
legend({'mtimes','my_mtimes'}, 'Interpreter','none', 'Location','NorthWest')
xlabel('Size N'), ylabel('Time [msec]'), title('Matrix Multiplication')
axis tight
Run Code Online (Sandbox Code Playgroud)

额外的

为了完整起见,下面是实现广义矩阵乘法的两种更简单的方法(如果您想比较性能,请将函数的最后部分替换my_mtimes为其中任何一个)。我什至都懒得去贴出他们过去的时间:)

C = zeros(M,N, args{:});
for m=1:M
    for n=1:N
        %C(m,n) = A(m,:) * B(:,n);
        %C(m,n) = sum(bsxfun(@times, A(m,:)', B(:,n)));
        C(m,n) = outFcn(bsxfun(inFcn, A(m,:)', B(:,n)));
    end
end
Run Code Online (Sandbox Code Playgroud)

另一种方式(使用三环):

C = zeros(M,N, args{:});
P = size(A,2); % = size(B,1);
for m=1:M
    for n=1:N
        for p=1:P
            %C(m,n) = C(m,n) + A(m,p)*B(p,n);
            %C(m,n) = plus(C(m,n), times(A(m,p),B(p,n)));
            C(m,n) = outFcn([C(m,n) inFcn(A(m,p),B(p,n))]);
        end
    end
end
Run Code Online (Sandbox Code Playgroud)

接下来要尝试什么?

如果您想获得更多性能,则必须转向 C/C++ MEX 文件,以减少解释 MATLAB 代码的开销。您仍然可以通过从 MEX 文件调用优化的 BLAS/LAPACK 例程来利用它们(有关示例,请参阅本文的第二部分)。MATLAB 附带Intel MKL库,坦率地说,当涉及到 Intel 处理器上的线性代数计算时,该库是无与伦比的。

其他人已经提到了文件交换上的一些提交,这些提交将通用矩阵例程实现为 MEX 文件(请参阅@natan的答案)。如果将它们链接到优化的 BLAS 库,这些将特别有效。