How can I vectorize this code?

Jam*_*ame · score 12 · tags: matlab, vectorization

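I wrote a recursive function, but it takes a long time to run, so I vectorized it. However, the vectorized version does not produce the same result as the recursive one. Here is my non-vectorized code: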

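function visited = procedure_explore( u, adj_mat, visited )
% Recursive depth-first search: mark every node reachable from u.
visited(u) = 1;
neighbours = find(adj_mat(u,:));          % nodes with an edge u -> neighbour
for ii = 1:length(neighbours)
    if (visited(neighbours(ii)) == 0)     % explore only unvisited neighbours
        visited = procedure_explore( neighbours(ii), adj_mat, visited );
    end
end
end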

Here is my vectorized code:

function visited = procedure_explore_vec( u, adj_mat, visited )
visited(u) = 1;
neighbours = find(adj_mat(u,:));
len_neighbours=length(neighbours);
visited_neighbours_zero=visited(neighbours(1:len_neighbours)) == 0;
if(~isempty(visited_neighbours_zero))
    visited = procedure_explore_vec( neighbours(visited_neighbours_zero), adj_mat, visited );
end
end
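One likely reason the two versions differ: when `u` holds several nodes, `adj_mat(u,:)` is a matrix, and `find` on a matrix returns linear indices into that submatrix rather than node (column) numbers. Below is a minimal sketch of a corrected recursive variant, assuming `visited` is a 1-by-n row vector; the name `procedure_explore_vec_fixed` is hypothetical and not from the question.

```matlab
function visited = procedure_explore_vec_fixed( u, adj_mat, visited )
% Hypothetical corrected version of procedure_explore_vec.
% u may be a single node or a vector of nodes (the current frontier);
% visited is assumed to be a 1-by-n row vector.
visited(u) = 1;
% any(..., 1) merges the rows of all nodes in u into one 1-by-n row, so
% find returns column indices, i.e. node numbers, not linear indices.
neighbours = find(any(adj_mat(u,:), 1));
% keep only the neighbours that have not been visited yet
unvisited = neighbours(visited(neighbours) == 0);
if ~isempty(unvisited)
    visited = procedure_explore_vec_fixed( unvisited, adj_mat, visited );
end
end
```

Each call processes one whole level of the search, so the recursion depth equals the number of levels rather than the number of nodes, while the set of marked nodes is the same one the recursive function produces.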

Here is the test code:

function main
    adj_mat=[0 0 0 0;
             1 0 1 1;
             1 0 0 0;
             1 0 0 1];
    u=2;
    visited=zeros(1,size(adj_mat,1));
    tic
    visited = procedure_explore( u, adj_mat, visited )
    toc
    visited=zeros(1,size(adj_mat,1));
    tic
    visited = procedure_explore_vec( u, adj_mat, visited )
    toc
end

Here is the algorithm I want to implement: [image not available]

If it cannot be vectorized, a MEX solution would also be fine.
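If the main goal is to avoid recursion overhead, one non-MEX option is to replace the recursion with an explicit stack. This iterative depth-first search is a swapped-in technique, not anything from the question; the name `procedure_explore_iterative` is hypothetical, it returns a logical vector rather than doubles, and its speed relative to the recursive version has not been measured here.

```matlab
function visited = procedure_explore_iterative( u, adj_mat )
% Hypothetical non-recursive variant: depth-first search with an explicit stack.
n = size(adj_mat, 1);
visited = false(1, n);
stack = zeros(1, n);     % preallocated; each node is pushed at most once
top = 1;
stack(top) = u;
visited(u) = true;
while top > 0
    v = stack(top);      % pop the most recently pushed node
    top = top - 1;
    neighbours = find(adj_mat(v, :));
    for w = neighbours
        if ~visited(w)
            visited(w) = true;   % mark on push so no node is pushed twice
            top = top + 1;
            stack(top) = w;
        end
    end
end
end
```

Because nodes are marked when pushed, the stack never holds more than n entries. The visiting order differs from the recursive version, but the set of reached nodes is the same.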

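Benchmark update: this benchmark was run on MATLAB R2017a with the same 4-node test graph. The reported ratio is t_original / t_other, so a value below 1 means the original code is faster than the other methods: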

Speed up between original and logical methods is 0.39672
Speed up between original and nearest methods is 0.0042583

Full code:

function main_recersive
    adj_mat=[0 0 0 0;
             1 0 1 1;
             1 0 0 0;
             1 0 0 1];
    u=2;
    visited=zeros(1,size(adj_mat,1));
    f_original=@()(procedure_explore( u, adj_mat, visited ));
    t_original=timeit(f_original);

    f_logical=@()(procedure_explore_logical( u, adj_mat ));
    t_logical=timeit(f_logical);

    f_nearest=@()(procedure_explore_nearest( u, adj_mat,visited ));
    t_nearest=timeit(f_nearest);

    disp(['Speed up between original and logical methods is ',num2str(t_original/t_logical)])
    disp(['Speed up between original and nearest methods is ',num2str(t_original/t_nearest)])    

end

function visited = procedure_explore( u, adj_mat, visited )
    visited(u) = 1;
    neighbours = find(adj_mat(u,:));
    for ii = 1:length(neighbours)
        if (visited(neighbours(ii)) == 0)
            visited = procedure_explore( neighbours(ii), adj_mat, visited );
        end
    end
end

function visited = procedure_explore_nearest( u, adj_mat, visited )
    % add u since your function also includes it.
    nodeIDs = [nearest(digraph(adj_mat),u,inf) ; u];
    % transform to output format of your function
    visited = zeros(1,size(adj_mat,1));
    visited(nodeIDs) = 1;

end 

function visited = procedure_explore_logical( u, adj_mat )
   visited = false(1, size(adj_mat, 1));
   visited(u) = true;
   new_visited = visited;
   while any(new_visited)
      visited = any([visited; new_visited], 1);
      new_visited = any(adj_mat(new_visited, :), 1);
      new_visited = and(new_visited, ~visited);
   end
end

bea*_*ker · score 4

Here is a fun little function that performs a non-recursive breadth-first search on the graph.

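function visited = procedure_explore_logical( u, adj_mat )
   % nodes reached so far, as a logical row vector
   visited = false(1, size(adj_mat, 1));
   visited(u) = true;
   % frontier: nodes discovered in the previous step
   new_visited = visited;

   while any(new_visited)
      visited = any([visited; new_visited], 1);
      % every node that some frontier node has an edge to
      new_visited = any(adj_mat(new_visited, :), 1);
      % drop nodes that were already visited
      new_visited = and(new_visited, ~visited);
   end
end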

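In Octave, it runs about 50 times faster than the recursive version on a 100x100 adjacency matrix. You will have to benchmark it in MATLAB yourself to see how it performs there.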
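A variant of the same breadth-first idea expands the frontier with a sparse matrix-vector product instead of logical row indexing. This is a sketch, not part of the answer above; the name `procedure_explore_matvec` is hypothetical, and whether it beats `procedure_explore_logical` depends on graph size and density, so it would need benchmarking as well.

```matlab
function visited = procedure_explore_matvec( u, adj_mat )
% Hypothetical variant of procedure_explore_logical: the frontier is expanded
% with a matrix-vector product instead of logical row indexing.
A = sparse(double(adj_mat ~= 0));   % 0/1 sparse adjacency, assumed n-by-n
n = size(A, 1);
visited = false(1, n);
visited(u) = true;
frontier = visited;
while any(frontier)
    % frontier * A is nonzero in column j iff some frontier node has an edge to j
    reached = full(double(frontier) * A) > 0;
    frontier = reached & ~visited;   % keep only newly discovered nodes
    visited = visited | frontier;
end
end
```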