MATLAB是否优化了诊断(A*B)?

Emi*_*erg 7 matlab matrix linear-algebra

假设我有两个非常大的矩阵A(M-by-N)和B(N-by-M).我需要对角线A*B.计算完全A*B需要M*M*N次乘法,而计算它的对角线只需要M*N次乘法,因为不需要计算最终会在对角线之外的元素.

MATLAB是否实现了这种diag(A*B)自动优化和动态优化,或者我最好在这种情况下使用for循环?

Div*_*kar 11

人们也可以实现diag(A*B)sum(A.*B',2).让我们根据此问题的建议对此进行基准测试以及所有其他实现/解决方案.

下面列出了作为函数实现的不同方法,用于基准测试:

  1. 求和方法-1

    function out = sum_mult_method1(A,B)
    
    out = sum(A.*B',2);
    
    Run Code Online (Sandbox Code Playgroud)
  2. 求和方法-2

    function out = sum_mult_method2(A,B)
    
    out = sum(A.'.*B).';
    
    Run Code Online (Sandbox Code Playgroud)
  3. For循环方法

    function out = for_loop_method(A,B)
    
    M = size(A,1);
    out = zeros(M,1);
    for i=1:M
        out(i) = A(i,:) * B(:,i);
    end
    
    Run Code Online (Sandbox Code Playgroud)
  4. 全/直接乘法

    function out = direct_mult_method(A,B)
    
    out = diag(A*B);
    
    Run Code Online (Sandbox Code Playgroud)
  5. Bsxfun法

    function out = bsxfun_method(A,B)
    
    out = sum(bsxfun(@times,A,B.'),2);
    
    Run Code Online (Sandbox Code Playgroud)

基准代码

num_runs = 1000;
M_arr = [100 200 500 1000];
N = 4;

%// Warm up tic/toc.
tic();
elapsed = toc();
tic();
elapsed = toc();

for k2 = 1:numel(M_arr)
    M = M_arr(k2);

    fprintf('\n')
    disp(strcat('*** Benchmarking sizes are M =',num2str(M),' and N = ',num2str(N)));

    A = randi(9,M,N);
    B = randi(9,N,M);

    disp('1. Sum-multiplication method-1');
    tic
    for k = 1:num_runs
        out1 = sum_mult_method1(A,B);
    end
    toc
    clear out1

    disp('2. Sum-multiplication method-2');
    tic
    for k = 1:num_runs
        out2 = sum_mult_method2(A,B);
    end
    toc
    clear out2

    disp('3. For-loop method');
    tic
    for k = 1:num_runs
        out3 = for_loop_method(A,B);
    end
    toc
    clear out3

    disp('4. Direct-multiplication method');
    tic
    for k = 1:num_runs
        out4 = direct_mult_method(A,B);
    end
    toc
    clear out4

    disp('5. Bsxfun method');
    tic
    for k = 1:num_runs
        out5 = bsxfun_method(A,B);
    end
    toc
    clear out5

end
Run Code Online (Sandbox Code Playgroud)

结果

*** Benchmarking sizes are M =100 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.015242 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.015180 seconds.
3. For-loop method
Elapsed time is 0.192021 seconds.
4. Direct-multiplication method
Elapsed time is 0.065543 seconds.
5. Bsxfun method
Elapsed time is 0.054149 seconds.

*** Benchmarking sizes are M =200 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.009138 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.009428 seconds.
3. For-loop method
Elapsed time is 0.435735 seconds.
4. Direct-multiplication method
Elapsed time is 0.148908 seconds.
5. Bsxfun method
Elapsed time is 0.030946 seconds.

*** Benchmarking sizes are M =500 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.033287 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.026405 seconds.
3. For-loop method
Elapsed time is 0.965260 seconds.
4. Direct-multiplication method
Elapsed time is 2.832855 seconds.
5. Bsxfun method
Elapsed time is 0.034923 seconds.

*** Benchmarking sizes are M =1000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.026068 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.032850 seconds.
3. For-loop method
Elapsed time is 1.775382 seconds.
4. Direct-multiplication method
Elapsed time is 13.764870 seconds.
5. Bsxfun method
Elapsed time is 0.044931 seconds.
Run Code Online (Sandbox Code Playgroud)

中级结论

看起来sum-multiplication方法是最好的方法,但bsxfun方法似乎是追赶它们M从100增加到1000.

接下来,仅使用sum-multiplicationbsxfun方法测试更高的基准测试大小.尺寸是 -

M_arr = [1000 2000 5000 10000 20000 50000];
Run Code Online (Sandbox Code Playgroud)

结果是 -

*** Benchmarking sizes are M =1000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.030390 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.032334 seconds.
5. Bsxfun method
Elapsed time is 0.047377 seconds.

*** Benchmarking sizes are M =2000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.040111 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.045132 seconds.
5. Bsxfun method
Elapsed time is 0.060762 seconds.

*** Benchmarking sizes are M =5000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.099986 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.103213 seconds.
5. Bsxfun method
Elapsed time is 0.117650 seconds.

*** Benchmarking sizes are M =10000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.375604 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.273726 seconds.
5. Bsxfun method
Elapsed time is 0.226791 seconds.

*** Benchmarking sizes are M =20000 and N =4
1. Sum-multiplication method-1
Elapsed time is 1.906839 seconds.
2. Sum-multiplication method-2
Elapsed time is 1.849166 seconds.
5. Bsxfun method
Elapsed time is 1.344905 seconds.

*** Benchmarking sizes are M =50000 and N =4
1. Sum-multiplication method-1
Elapsed time is 5.159177 seconds.
2. Sum-multiplication method-2
Elapsed time is 5.081211 seconds.
5. Bsxfun method
Elapsed time is 3.866018 seconds.
Run Code Online (Sandbox Code Playgroud)

替代基准测试代码(带有`timeit)

num_runs = 1000;
M_arr = [1000 2000 5000 10000 20000 50000 100000 200000 500000 1000000];
N = 4;

timeall = zeros(5,numel(M_arr));
for k2 = 1:numel(M_arr)
    M = M_arr(k2);

    A = rand(M,N);
    B = rand(N,M);

    f = @() sum_mult_method1(A,B);
    timeall(1,k2) = timeit(f);
    clear f

    f = @() sum_mult_method2(A,B);
    timeall(2,k2) = timeit(f);
    clear f

    f = @() bsxfun_method(A,B);
    timeall(5,k2) = timeit(f);
    clear f

end

figure,
hold on
plot(M_arr,timeall(1,:),'-ro')
plot(M_arr,timeall(2,:),'-ko')
plot(M_arr,timeall(5,:),'-.b')
legend('sum-method1','sum-method2','bsxfun-method')
xlabel('M ->')
ylabel('Time(sec) ->')
Run Code Online (Sandbox Code Playgroud)

情节

在此输入图像描述

最终结论

似乎sum-multiplication方法很好,直到某个阶段,在M=5000标记周围,之后bsxfun似乎有轻微的上风.

未来的工作

人们可以研究变化N并研究这里提到的实现的性能.

  • 这些时间看起来很小,可靠,而且你没有做任何热身.如果您不熟悉可靠的基准测试,但是您最近有一个Matlab,我建议使用`timeit`.另外,测试`bsxfun`怎么样? (2认同)