Vit*_*meo 30 c++ algorithm multithreading asynchronous future
我有一个编译时有向异步任务的非循环图.DAG显示了任务之间的依赖关系:通过分析它,可以了解哪些任务可以并行运行(在单独的线程中)以及哪些任务需要等待其他任务在它们开始之前完成(依赖关系).
我要生成从DAG回调链,使用boost::future和.then(...),when_all(...)延续辅助功能.这一代的结果将是一个函数,当被调用时,它将启动回调链并执行DAG所描述的任务,并行运行尽可能多的任务.
但是,我遇到了麻烦,找到了适用于所有情况的通用算法.
我做了一些图纸,以使问题更容易理解.这是一个图例,它将向您展示图纸中的符号含义:
让我们从一个简单的线性DAG开始:
这种依赖关系图包括三项任务(A,B,和C).C取决于B.B取决于A.这里没有并行性的可能性 - 生成算法会构建类似于此的东西:
boost::future<void> A, B, C, end;
A.then([]
    {
        B.then([]
            {
                C.get();
                end.get();
            });
    });
(请注意,所有代码示例都不是100%有效 - 我忽略了移动语义,转发和lambda捕获.)
有许多方法可以解决这个线性DAG:无论是从结束开始还是从开始开始,构建正确的回调链都是微不足道的.
当引入分支和连接时,事情开始变得更加复杂.
这是一个带有fork/join的DAG:
很难想到与此DAG匹配的回调链.如果我尝试向后工作,从最后开始,我的推理如下:
end取决于B和D.(加入)
D取决于C.B并C依赖A.(叉子)可能的链看起来像这样:
boost::future<void> A, B, C, D, end;
A.then([]
    {
        boost::when_all(B, C.then([]
                               {
                                   D.get();
                               }))
            .then([]
                {
                    end.get();
                });
    });
我发现很难手工编写这个链条,我也怀疑它的正确性.我想不出一种实现可以生成这种算法的算法的一般方法 - 由于when_all需要将其参数移入其中,所以还存在额外的困难.
让我们看看最后一个,甚至更复杂的例子:
在这里,我们希望尽可能地利用并行性.考虑任务E:E可以与任何一个并行运行[B, C, D].
这是一个可能的回调链:
boost::future<void> A, B, C, D, E, F, end;
A.then([]
    {
        boost::when_all(boost::when_all(B, C).then([]
                            {
                                D.get();
                            }),
            E)
            .then([]
                {
                    F.then([]
                        {
                            end.get();
                        });
                });
    });
我试图用几种方法提出一种通用算法:
从DAG开始,尝试使用.then(...)continuation 建立链.这不适用于连接,因为目标连接任务将重复多次.
从DAG结束开始,尝试使用when_all(...)continuation 生成链.这会因fors而失败,因为创建fork的节点会重复多次.
显然,"广度优先遍历"方法在这里效果不佳.从我手写的代码示例中,似乎算法需要知道forks和join,并且需要能够正确地混合.then(...)和when_all(...)继续.
以下是我的最后一个问题:
是否始终可以future从任务依赖关系的DAG 生成基于回调的链,其中每个任务在回调链中只出现一次?
如果是这样,在给定任务依赖性DAG构建回调链的情况下,如何实现通用算法?
编辑1:
我们的想法是([dependencies...] -> [dependents...])从DAG 生成地图数据结构,并从该地图生成回调链.
如果len(dependencies...) > 1,那么value是一个连接节点.
如果len(dependents...) > 1,那么key是一个fork节点.
地图中的所有键值对都可以表示为when_all(keys...).then(values...)continuation.
困难的部分是找出正确的顺序,即"扩展" (想想类似于解析器的东西)节点以及如何将fork/join连接连接在一起.
考虑由图像4生成的以下地图.
depenendencies  |  dependents
----------------|-------------
[F]             :  [end]
[D, E]          :  [F]
[B, C]          :  [D]
[A]             :  [E, C, B]
[begin]         :  [A]
通过应用某种类似解析器的减少/传递,我们可以得到一个"干净"的回调链:
// First pass:
// Convert everything to `when_all(...).then(...)` notation
when_all(F).then(end)
when_all(D, E).then(F)
when_all(B, C).then(D)
when_all(A).then(E, C, B)
when_all(begin).then(A)
// Second pass:
// Solve linear (trivial) transformations
when_all(D, E).then(
    when_all(F).then(end)
)
when_all(B, C).then(D)
when_all(
    when_all(begin).then(A)
).then(E, C, B)
// Third pass:
// Solve fork/join transformations
when_all(
    when_all(begin).then(A)
).then(
    when_all(
        E, 
        when_all(B, C).then(D)
    ).then(
        when_all(F).then(end)
    )   
)
第三遍是最重要的一个,但也是一个看起来很难设计算法的.
请注意如何[B, C]有内被发现[E, C, B]列表,以及如何在[D, E]依赖关系列表,D必须解释为的结果when_all(B, C).then(D),并与链接在一起E的when_all(E, when_all(B, C).then(D)).
也许整个问题可以简化为:
给定一个由[dependencies...] -> [dependents...]键值对组成的映射,如何实现将这些对转换为when_all(...)/ .then(...)continuation链的算法?
编辑2:
这是我为上述方法提出的一些伪代码.它似乎适用于我尝试的DAG,但我需要花更多的时间在它上面并"精神上"测试它与其他更棘手的DAG配置.
最简单的方法是从图形的入口节点开始,就像您手动编写代码一样.为了解决这个join问题,你不能使用递归解决方案,你需要对图形进行拓扑排序,然后根据排序构建图形.
这样可以保证在构建节点时,已经创建了所有前一个节点.
为了实现这一目标,我们可以使用DFS,反向排序.
进行拓扑排序后,您可以忘记原始节点ID,并在列表中引用其编号的节点.为此,您需要创建一个编译时间映射,允许使用拓扑排序中的节点索引而不是节点原始节点索引来检索节点前置任务.
编辑:跟进如何在编译时实现拓扑排序,我重构了这个答案.
要在同一页面上,我将假设您的图形如下所示:
struct mygraph
{
     template<int Id>
     static constexpr auto successors(node_id<Id>) ->
        list< node_id<> ... >; //List of successors for the input node
     template<int Id>
     static constexpr auto predecessors(node_id<Id>) ->
        list< node_id<> ... >; //List of predecessors for the input node
     //Get the task associated with the given node.
     template<int Id>
     static constexpr auto task(node_id<Id>);
     using entry_node = node_id<0>;
};
您需要的基本组成部分是node-id的编译时集.在TMP中,集合也是一个列表,仅仅因为set<Ids...>按照Ids事项的顺序.这意味着您可以使用相同的数据结构来编码有关已访问的节点以及同时生成的排序的信息.
/** Topological sort using DFS with reverse-postordering **/
template<class Graph>
struct topological_sort
{
private:
    struct visit;
    // If we reach a node that we already visited, do nothing.
    template<int Id, int ... Is>
    static constexpr auto visit_impl( node_id<Id>,
                                      set<Is...> visited,
                                      std::true_type )
    {
        return visited;
    }
    // This overload kicks in when node has not been visited yet.
    template<int Id, int ... Is>
    static constexpr auto visit_impl( node_id<Id> node,
                                      set<Is...> visited,
                                      std::false_type )
    {
        // Get the list of successors for the current node
        constexpr auto succ = Graph::successors(node);
        // Reverse postordering: we call insert *after* visiting the successors
        // This will call "visit" on each successor, updating the
        // visited set after each step.
        // Then we insert the current node in the set.
        // Notice that if the graph is cyclic we end up in an infinite
        // recursion here.
        return fold( succ,
                     visited,
                     visit() ).insert(node);
        // Conventional DFS would be:
        // return fold( succ, visited.insert(node), visit() );
    }
    struct visit
    {
        // Dispatch to visit_impl depending on the result of visited.contains(node)
        // Note that "contains" returns a type convertible to
        // integral_constant<bool,x>
        template<int Id, int ... Is>
        constexpr auto operator()( set<Is...> visited, node_id<Id> node ) const
        {
            return visit_impl(node, visited, visited.contains(node) );
        }
    };
public:
    template<int StartNodeId>
    static constexpr auto compute( node_id<StartNodeId> node )
    {
        // Start visiting from the entry node
        // The set of visited nodes is initially empty.
        // "as_list" converts set<Is ... > to list< node_id<Is> ... >.
        return reverse( visit()( set<>{}, node ).as_list() );
    }
};
该算法与从最后一个例子的曲线图(假设A = node_id<0>,B = node_id<1>等),产生list<A,B,C,D,E,F>.
这只是一个适配器,它根据给定的顺序修改图中每个节点的Id.因此,假设返回前面的步骤list<C,D,A,B>,这graph_map会将索引0映射到C索引1 D,等等.
template<class Graph, class List>
class graph_map
{   
    // Convert a node_id from underlying graph.
    // Use a function-object so that it can be passed to algorithms.
    struct from_underlying
    { 
        template<int I>
        constexpr auto operator()(node_id<I> id) 
        { return node_id< find(id, List{}) >{}; }
    };
    struct to_underlying
    { 
        template<int I>
        constexpr auto operator()(node_id<I> id) 
        { return get<I>(List{}); }
    };
public:        
    template<int Id>
    static constexpr auto successors( node_id<Id> id )
    {
        constexpr auto orig_id = to_underlying()(id);
        constexpr auto orig_succ = Graph::successors( orig_id );
        return transform( orig_succ, from_underlying() );
    }
    template<int Id>
    static constexpr auto predecessors( node_id<Id> id )
    {
        constexpr auto orig_id = to_underlying()(id);
        constexpr auto orig_succ = Graph::predecessors( orig_id );
        return transform( orig_succ, from_underlying() );
    }
    template<int Id>
    static constexpr auto task( node_id<Id> id )
    {
        return Graph::task( to_underlying()(id) );
    }
    using entry_node = decltype( from_underlying()( typename Graph::entry_node{} ) );
};
我们现在可以按顺序迭代每个节点id.由于我们构建图形图的方式,我们知道对于每个可能的节点,所有前驱的I节点id都小于.II
// Returns a tuple<> of futures
template<class GraphMap, class ... Ts>
auto make_cont( std::tuple< future<Ts> ... > && pred )
{
     // The next node to work with is N:
     constexpr auto current_node = node_id< sizeof ... (Ts) >();
     // Get a list of all the predecessors for the current node.
     auto indices = GraphMap::predecessors( current_node );
     // "select" is some magic function that takes a tuple of Ts
     // and an index_sequence, and returns a tuple of references to the elements 
     // from the input tuple that are in the indices list. 
     auto futures = select( pred, indices );
     // Assuming you have an overload of when_all that takes a tuple,
     // otherwise use C++17 apply.
     auto join = when_all( futures );
     // Note: when_all with an empty parameter list returns a future< tuple<> >,
     // which is always ready.
     // In general this has to be a shared_future, but you can avoid that
     // by checking if this node has only one successor.
     auto next = join.then( GraphMap::task( current_node ) ).share();
     // Return a new tuple of futures, pushing the new future at the back.
     return std::tuple_cat( std::move(pred),
                            std::make_tuple(std::move(next)) );         
}
// Returns a tuple of futures, you can take the last element if you
// know that your DAG has only one leaf, or do some additional 
// processing to extract only the leaf nodes.
template<class Graph>
auto make_callback_chain()
{
    constexpr auto entry_node = typename Graph::entry_node{};
    constexpr auto sorted_list = 
         topological_sort<Graph>::compute( entry_node );
    using map = graph_map< Graph, decltype(sorted_list) >;
    // Note: we are not really using the "index" in the functor here, 
    // we only want to call make_cont once for each node in the graph
    return fold( sorted_list, 
                 std::make_tuple(), //Start with an empty tuple
                 []( auto && tuple, auto index )
                 {
                     return make_cont<map>(std::move(tuple));
                 } );
}
如果可能发生冗余依赖关系,请首先删除它们(请参阅https://mathematica.stackexchange.com/questions/33638/remove-redundant-dependencies-from-a-directed-acyclic-graph).
然后执行以下图形转换(在合并节点中构建子表达式),直到您到达单个节点(以类似于计算电阻网络的方式):

*:其他传入或传出依赖项,具体取决于放置
(...):表达式在单个节点中
Java代码,包括更复杂示例的设置:
public class DirectedGraph {
  /** Set of all nodes in the graph */
  static Set<Node> allNodes = new LinkedHashSet<>();
  static class Node {
    /** Set of all preceeding nodes */
    Set<Node> prev = new LinkedHashSet<>();
    /** Set of all following nodes */
    Set<Node> next = new LinkedHashSet<>();
    String value;
    Node(String value) {
      this.value = value;
      allNodes.add(this);
    }
    void addPrev(Node other) {
      prev.add(other);
      other.next.add(this);
    }
    /** Returns one of the next nodes */
    Node anyNext() {
      return next.iterator().next();
    }
    /** Merges this node with other, then removes other */
    void merge(Node other) {
      prev.addAll(other.prev);
      next.addAll(other.next);
      for (Node on: other.next) {
        on.prev.remove(other);
        on.prev.add(this);
      }
      for (Node op: other.prev) {
        op.next.remove(other);
        op.next.add(this);
      }
      prev.remove(this);
      next.remove(this);
      allNodes.remove(other);
    }
    public String toString() {
      return value;
    }
  }
  /** 
   * Merges sequential or parallel nodes following the given node.
   * Returns true if any node was merged.
   */
  public static boolean processNode(Node node) {
    // Check if we are the start of a sequence. Merge if so.
    if (node.next.size() == 1 && node.anyNext().prev.size() == 1) {
      Node then = node.anyNext();
      node.value += " then " + then.value;
      node.merge(then);
      return true;
    }
    // See if any of the next nodes has a parallel node with
    // the same one level indirect target. 
    for (Node next : node.next) {
      // Nodes must have only one in and out connection to be merged.
      if (next.prev.size() == 1 && next.next.size() == 1) {
        // Collect all parallel nodes with only one in and out connection 
        // and the same target; the same source is implied by iterating over 
        // node.next again.
        Node target = next.anyNext().next();
        Set<Node> parallel = new LinkedHashSet<Node>();
        for (Node other: node.next) {
          if (other != next && other.prev.size() == 1
             && other.next.size() == 1 && other.anyNext() == target) {
            parallel.add(other);
          }
        }
        // If we have found any "parallel" nodes, merge them
        if (parallel.size() > 0) {
          StringBuilder sb = new StringBuilder("allNodes(");
          sb.append(next.value);
          for (Node other: parallel) {
            sb.append(", ").append(other.value);
            next.merge(other);
          }
          sb.append(")");
          next.value = sb.toString();
          return true;
        }
      }
    }
    return false;
  }
  public static void main(String[] args) {
    Node a = new Node("A");
    Node b = new Node("B");
    Node c = new Node("C");
    Node d = new Node("D");
    Node e = new Node("E");
    Node f = new Node("F");
    f.addPrev(d);
    f.addPrev(e);
    e.addPrev(a);
    d.addPrev(b);
    d.addPrev(c);
    b.addPrev(a);
    c.addPrev(a);
    boolean anyChange;
    do {
      anyChange = false;
      for (Node node: allNodes) {
        if (processNode(node)) {
          anyChange = true;
          // We need to leave the inner loop here because changes
          // invalidate the for iteration. 
          break;
        }
      }
      // We are done if we can't find any node to merge.
    } while (anyChange);
    System.out.println(allNodes.toString());
  }
}
输出: A then all(E, all(B, C) then D) then F