MATLAB中的高效树实现

mik*_*ola 16 tree performance matlab data-structures

MATLAB中的树类

我正在MATLAB中实现树数据结构.将新的子节点添加到树中,分配和更新与节点相关的数据值是我期望执行的典型操作.每个节点都具有data与之关联的相同类型.删除节点对我来说不是必需的.到目前为止,我已经决定继承自handle类的类实现,以便能够将对节点的引用传递给将修改树的函数.

编辑:12月2日

首先,感谢到目前为止评论和答案中的所有建议.他们已经帮助我改进了我的树类.

有人建议尝试digraph在R2015b中引入.我还没有探索过这个,但是因为它不像继承的类那样作为引用参数工作handle,我有点怀疑它在我的应用程序中是如何工作的.在这一点上,我还不清楚使用自定义data节点和边缘来使用它是多么容易.

编辑:(12月3日)有关主要应用程序的更多信息:MCTS

最初,我认为主要应用程序的细节只是边际兴趣,但自从阅读@FirefoxMetzger 的评论和答案后,我意识到它具有重要意义.

我正在实现一种蒙特卡罗树搜索算法.搜索树以迭代方式被探索和扩展.维基百科提供了一个很好的过程图形概述: 蒙特卡罗树搜索

在我的应用程序中,我执行了大量的搜索迭代.在每次搜索迭代中,我遍历当前树,从根开始直到叶节点,然后通过添加新节点展开树,并重复.由于该方法基于随机采样,因此在每次迭代开始时,我不知道在每次迭代时我将完成哪个叶节点.相反,这是由data当前在树中的节点和随机样本的结果共同确定的.我在一次迭代中访问的节点都有data更新.

示例:我在n有几个孩子的节点.我需要访问每个孩子中的数据并绘制一个随机样本,以确定我在搜索中移动到下一个孩子.重复此过程,直到到达叶节点.实际上,我通过调用searchroot上的函数来执行此操作,该函数将决定接下来要扩展哪个子search节点,递归调用该节点,依此类推,最后在到达叶节点时返回一个值.从递归函数返回时使用此值以更新data在搜索迭代期间访问的节点.

树可能非常不平衡,使得一些分支是非常长的节点链,而其他分支在根级之后快速终止并且不进一步扩展.

目前的实施

下面是我当前实现的示例,其中包含一些用于添加节点,查询树中节点深度或数量等的成员函数的示例,以此类推.

classdef stree < handle
    %   A class for a tree object that acts like a reference
    %   parameter.
    %   The tree can be traversed in both directions by using the parent
    %   and children information.
    %   New nodes can be added to the tree. The object will automatically
    %   keep track of the number of nodes in the tree and increment the
    %   storage space as necessary.

    properties (SetAccess = private)
        % Hold the data at each node
        Node = { [] };
        % Index of the parent node. The root of the tree as a parent index
        % equal to 0.
        Parent = 0;
        num_nodes = 0;
        size_increment = 1;
        maxSize = 1;
    end

    methods
        function [obj, root_ID] = stree(data, init_siz)
            % New object with only root content, with specified initial
            % size
            obj.Node = repmat({ data },init_siz,1);
            obj.Parent = zeros(init_siz,1);
            root_ID = 1;
            obj.num_nodes = 1;
            obj.size_increment = init_siz;
            obj.maxSize = numel(obj.Parent);
        end

        function ID = addnode(obj, parent, data)
            % Add child node to specified parent
            if obj.num_nodes < obj.maxSize
                % still have room for data
                idx = obj.num_nodes + 1;
                obj.Node{idx} = data;
                obj.Parent(idx) = parent;
                obj.num_nodes = idx;
            else
                % all preallocated elements are in use, reserve more memory
                obj.Node = [
                    obj.Node
                    repmat({data},obj.size_increment,1)
                    ];

                obj.Parent = [
                    obj.Parent
                    parent
                    zeros(obj.size_increment-1,1)];
                obj.num_nodes = obj.num_nodes + 1;

                obj.maxSize = numel(obj.Parent);

            end
            ID = obj.num_nodes;
        end

        function content = get(obj, ID)
            %% GET  Return the contents of the given node IDs.
            content = [obj.Node{ID}];
        end

        function obj = set(obj, ID, content)
            %% SET  Set the content of given node ID and return the modifed tree.
            obj.Node{ID} = content;
        end

        function IDs = getchildren(obj, ID)
            % GETCHILDREN  Return the list of ID of the children of the given node ID.
            % The list is returned as a line vector.
            IDs = find( obj.Parent(1:obj.num_nodes) == ID );
            IDs = IDs';
        end
        function n = nnodes(obj)
            % NNODES  Return the number of nodes in the tree.
            % Equal to root + those whose parent is not root.
            n = 1 + sum(obj.Parent(1:obj.num_nodes) ~= 0);
            assert( obj.num_nodes == n);
        end

        function flag = isleaf(obj, ID)
            % ISLEAF  Return true if given ID matches a leaf node.
            % A leaf node is a node that has no children.
            flag = ~any( obj.Parent(1:obj.num_nodes) == ID );
        end

        function depth = depth(obj,ID)
            % DEPTH return depth of tree under ID. If ID is not given, use
            % root.
            if nargin == 1
                ID = 0;
            end
            if obj.isleaf(ID)
                depth = 0;
            else
                children = obj.getchildren(ID);
                NC = numel(children);
                d = 0; % Depth from here on out
                for k = 1:NC
                    d = max(d, obj.depth(children(k)));
                end
                depth = 1 + d;
            end
        end
    end
end
Run Code Online (Sandbox Code Playgroud)

但是,性能有时很慢,树上的操作占用了我大部分的计算时间.有哪些具体方法可以提高实施效率?handle如果有性能提升,甚至可以将实现更改为继承类型之外的其他内容.

当前实施的分析结果

由于向树中添加新节点是最典型的操作(同时更新data节点),我对此进行了一些分析.我使用以下基准测试代码运行探查器Nd=6, Ns=10.

function T = benchmark(Nd, Ns)
% Tree benchmark. Nd: tree depth, Ns: number of nodes per layer
% Initialize tree
T = stree(rand, 10000);
add_layers(1, Nd);
    function add_layers(node_id, num_layers)
        if num_layers == 0
            return;
        end
        child_id = zeros(Ns,1);
        for s = 1:Ns
            % add child to current node
            child_id(s) = T.addnode(node_id, rand);

            % recursively increase depth under child_id(s)
            add_layers(child_id(s), num_layers-1);
        end
    end
end
Run Code Online (Sandbox Code Playgroud)

分析器的结果: Profiler结果

R2015b性能


已经发现R2015b 改善了MATLAB的OOP功能的性能.我重新调整了上述基准,确实观察到了性能的提高:

R2015b剖析器结果

所以这已经是好消息,虽然当然可以接受进一步的改进;)

以不同方式保留记忆

评论中也建议使用

obj.Node = [obj.Node; data; cell(obj.size_increment - 1,1)];
Run Code Online (Sandbox Code Playgroud)

保留更多的内存而不是当前的方法repmat.这略微改善了性能.我应该注意我的基准代码是针对虚拟数据的,因为实际上data更复杂,这可能会有所帮助.谢谢!Profiler结果如下:

zeeMonkeez内存储备风格

关于进一步提高绩效的问题

  1. 也许还有另一种方法来维护更高效的树的内存?遗憾的是,我通常不知道树中将有多少节点.
  2. 添加新节点和修改data现有节点是我在树上执行的最典型的操作.截至目前,它们实际上占用了我主要应用程序的大部分处理时间.对这些功能的任何改进都是最受欢迎的.

最后一点,我希望将实现保持为纯MATLAB.但是,诸如MEX之类的选项或使用某些集成的Java功能可能是可以接受的.

小智 9

TL:DR您深度复制存储在每个插入上的整个数据,初始化parentNode单元格大于您期望的需要.

您的数据确实具有树结构,但是您在实现中没有使用它.相反,实现的代码是查找表(实际上是2个表)的计算饥饿版本,它存储数据和树的关系数据.

我说这个的原因如下:

  • 要插入调用stree.addnote(母公司数据),将所有数据存储在树对象stree的字段Node = {}Parent = []
  • 你似乎知道你想要访问树中的哪个元素,因为没有给出搜索代码(如果你使用stree.getchild(ID),我有一些坏消息)
  • 处理完节点后,使用find()哪个节点进行列表搜索

这绝不意味着数据的实现是笨拙的,它甚至可能是最好的,这取决于你正在做什么.但它确实解释了您的内存分配问题,并提供了有关如何解决它们的提示.


将数据保留为查找表

存储数据的方法之一是保留基础查找表.我只会这样做,如果您知道ID要修改的第一个元素而不搜索它.这种情况允许您通过两个步骤提高结构效率.

首先将您的阵列初始化为大于您期望存储数据所需的数据.如果超出了查找表的容量,则初始化一个新的容量,即X字段更大,并生成旧数据的深层副本.如果您需要扩展一次或两次capcity(在所有插入期间),这可能不是问题,但在您的情况下,需要进行深层复制以进行插入!

其次,我会改变内部结构和合并两个表NodeParent.这样做的原因是代码中的反向传播需要O(depth_from_root*n),其中n是表中的节点数.这是因为find()将遍历每个父级的整个表.

相反,你可以实现类似的东西

table = cell(n,1) % n bigger then expected value
end_pointer = 1 % simple pointer to the first free value

function insert(data,parent_ID)
    if end_pointer < numel(table)
        content.data = data;
        content.parent = parent_ID;
        table{end_pointer} = content;
        end_pointer = end_pointer + 1;
    else
        % need more space, make sure its enough this time
        table = [table cell(end_pointer,1)];
        insert(data,parent_ID);
    end
end

function content = get_value(ID)
    content = table(ID);
end
Run Code Online (Sandbox Code Playgroud)

这样可以立即让您ID无需find()先访问父项,每次保存n次迭代,因此负担变为O(深度).如果你不知道你的初始节点,那么你必须知道find()那个,这需要花费O(n).

需要注意的是这种结构不需要is_leaf(),depth(),nnodes()get_children().如果您仍然需要那些我需要更深入了解您想要对数据做什么,因为这会极大地影响正确的结构.


树结构

如果您永远不知道第一个节点,那么这种结构是有意义的ID,因此总是必须搜索它.

好处是搜索任意音符与O(深度)一起工作,因此搜索是O(深度)而不是O(n),反向传播是O(深度^ 2)而不是O(深度+ n).请注意,对于完美平衡的树,深度可以是log(n),根据您的数据,可以是n,也可以是退化树的n,只是链接列表.

然而,为了建议一些正确的东西,我需要更多的洞察力,因为每个树形结构都有自己的nich.从我到目前为止看到的情况来看,我建议使用一个不平衡的树,它由一个节点通缉父节点给出的简单顺序"排序".这可以根据具体情况进一步优化

  • 是否可以定义数据的总订单
  • 你如何对待双重值(相同的数据出现两次)
  • 你的数据规模是多少(数千,数百万......)
  • 是一个与反向传播配对的查找/搜索
  • 你的数据上的'亲子'链是多长时间(或者树使用这个简单的顺序有多平衡和深度)
  • 总是只有一个父母或是同一个元素插入两次与不同的父母

我很乐意为上面的树提供示例代码,请给我留言.

编辑:在你的情况下,一个不平衡的树(这是建立在MCTS上的并列)似乎是最好的选择.下面的代码假定数据是分开的state,score并且进一步表示a state是唯一的.如果不是这样仍然有效,那么可能会优化以提高MCTS的性能.

classdef node < handle
    % A node for a tree in a MCTS
    properties
        state = {}; %some state of the search space that identifies the node
        score = 0;
        childs = cell(50,1);
        num_childs = 0;
    end
    methods
        function obj = node(state)
            % for a new node simulate a score using MC
            obj.score = simulate_from(state); % TODO implement simulation state -> finish
            obj.state = state;
        end
        function value = update(obj)
            % update the this node using MC recursively
            if obj.num_childs == numel(obj.childs)
                % there are to many childs, we have to expand the table
                obj.childs = [obj.childs cell(obj.num_childs,1)];
            end
            if obj.do_exploration() || obj.num_childs == 0
                % explore a potential state
                state_to_explore = obj.explore();

                %check if state has already been visited
                terminate = false;
                idx = 1;
                while idx <= obj.num_childs && ~terminate
                    if obj.childs{idx}.state_equals(state_to_explore)
                        terminate = true;
                    end
                    idx = idx + 1;
                end

                %preform the according action based on search
                if idx > obj.num_childs
                    % state has never been visited
                    % this action terminates the update recursion 
                    % and creates a new leaf
                    obj.num_childs = obj.num_childs + 1;
                    obj.childs{obj.num_childs} = node(state_to_explore);
                    value = obj.childs{obj.num_childs}.calculate_value();
                    obj.update_score(value);
                else
                    % state has been visited at least once
                    value = obj.childs{idx}.update();
                    obj.update_score(value);
                end
            else
                % exploit what we know already
                best_idx = 1;
                for idx = 1:obj.num_childs
                    if obj.childs{idx}.score > obj.childs{best_idx}.score
                        best_idx = idx;
                    end
                end
                value = obj.childs{best_idx}.update();
                obj.update_score(value);
            end
            value = obj.calculate_value();
        end
        function state = explore(obj)
            %select a next state to explore, that may or may not be visited
            %TODO
        end
        function bool = do_exploration(obj)
            % decide if this node should be explored or exploited
            %TODO
        end
        function bool = state_equals(obj, test_state)
            % returns true if the nodes state is equal to test_state
            %TODO
        end
        function update_score(obj, value)
            % updates the score based on some value
            %TODO
        end
        function calculate_value(obj)
            % returns the value of this node to update previous nodes
            %TODO
        end
    end
end
Run Code Online (Sandbox Code Playgroud)

关于代码的一些评论:

  • 根据设置,obj.calculate_value()可能不需要.例如,如果它是某个值,可以通过单独评估孩子的分数来计算
  • 如果a state可以有多个父项,则重用note对象并将其覆盖在结构中是有意义的
  • 因为每个人都node知道它的所有子node节点,所以可以使用根节点轻松生成子树
  • 搜索树(没有任何更新)是一个简单的递归贪婪搜索
  • 根据搜索的分支因素,可能值得一次访问每个可能的子节点(在节点初始化时),然后randsample(obj.childs,1)进行探索,因为这样可以避免复制/重新分配子数组
  • parent当树以递归方式更新时,对属性进行编码,value在完成节点更新后传递给父节点
  • 我重新分配内存的唯一一次是当一个节点有超过50个子节点时,我只对该单个节点进行重新分配

这应该运行得更快,因为它只是担心选择树的任何部分而不接触任何其他部分.


小智 6

我知道这可能听起来很愚蠢...但是如何保持自由节点的数量而不是节点的总数?这将需要与常量(为零)进行比较,这是单一属性访问.

另外一个伏都教的改善将是移动.maxSize靠近.num_nodes,并把这两个之前.Node细胞.像这样,由于.Node属性的增长,它们在内存中的位置不会相对于对象的开头发生变化(这里的巫术是我猜测MATLAB中对象的内部实现).

稍后编辑当我.Node在属性列表末尾移动时,通过扩展.Node属性消耗了大部分执行时间,如预期的那样(5.45秒,相比之下,您提到的比较为1.25秒).