使用 ForkJoin 和 Streams 构建自适应网格细化

Ben*_*Ben 5 java concurrency fork-join java-stream

我想在 3D 中构建自适应网格细化。

基本原理如下:

我有一组具有唯一单元 ID 的单元。我测试每个单元,看看它是否需要改进。

  • 如果需要细化,则创建 8 个新的子单元格并将它们添加到单元格列表中以检查细化。
  • 否则,这是一个叶节点,我将其添加到我的叶节点列表中。

我想使用 ForkJoin 框架和 Java 8 流来实现它。我读了这篇文章,但我不知道如何将其应用到我的案例中。

现在,我想到的是:

public class ForkJoinAttempt {
    private final double[] cellIds;

    public ForkJoinAttempt(double[] cellIds) {
        this.cellIds = cellIds;
    }

    public void refineGrid() {
        ForkJoinPool pool = ForkJoinPool.commonPool();
        double[] result = pool.invoke(new RefineTask(100));
    }

    private class RefineTask extends RecursiveTask<double[]> {
        final double cellId;

        private RefineTask(double cellId) {
            this.cellId = cellId;
        }

        @Override
        protected double[] compute() {
            return ForkJoinTask.invokeAll(createSubtasks())
                    .stream()
                    .map(ForkJoinTask::join)
                    .reduce(new double[0], new Concat());
        }
    }

    private double[] refineCell(double cellId) {
        double[] result;
        if (checkCell()) {
            result = new double[8];

            for (int i = 0; i < 8; i++) {
                result[i] = Math.random();
            }

        } else {
            result = new double[1];
            result[0] = cellId;
        }

        return result;
    }

    private Collection<RefineTask> createSubtasks() {
        List<RefineTask> dividedTasks = new ArrayList<>();

        for (int i = 0; i < cellIds.length; i++) {
            dividedTasks.add(new RefineTask(cellIds[i]));
        }
        
        return dividedTasks;
    }

    private class Concat implements BinaryOperator<double[]>  {

        @Override
        public double[] apply(double[] a, double[] b) {
            int aLen = a.length;
            int bLen = b.length;

            @SuppressWarnings("unchecked")
            double[] c = (double[]) Array.newInstance(a.getClass().getComponentType(), aLen + bLen);
            System.arraycopy(a, 0, c, 0, aLen);
            System.arraycopy(b, 0, c, aLen, bLen);

            return c;
        }
    }

    public boolean checkCell() {
        return Math.random() < 0.5;
    }
}
Run Code Online (Sandbox Code Playgroud)

...我被困在这里了。

目前这没有多大作用,因为我从不调用该refineCell函数。

我创建的所有内容也可能存在性能问题double[]。以这种方式合并它们可能也不是最有效的方法。

但首先,任何人都可以帮助我在这种情况下实现分叉连接吗?

该算法的预期结果是叶单元 ID 数组 ( double[])

编辑1:

感谢这些评论,我想出了一些效果更好的东西。

一些变化:

  • 我从数组转到列表。这对内存占用不利,因为我无法使用 Java 原语。但这使植入变得更简单。
  • 单元 ID 现在是长整型而不是双精度。
  • Id 不再是随机选择的:
    • 根级单元格的 ID 为 1、2、3 等;
    • 1 岁的孩子的 ID 为 10、11、12 等;
    • 2 岁儿童的 ID 为 20、21、22 等;
    • 你明白了...
  • 我精炼所有 ID 低于 100 的单元格

这使我能够为了这个示例更好地检查结果。

这是新的实现:

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.*;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public class ForkJoinAttempt {
    private static final int THRESHOLD = 2;
    private List<Long> leafCellIds;

    public void refineGrid(List<Long> cellsToProcess) {
        leafCellIds = ForkJoinPool.commonPool().invoke(new RefineTask(cellsToProcess));
    }

    public List<Long> getLeafCellIds() {
        return leafCellIds;
    }

    private class RefineTask extends RecursiveTask<List<Long>> {

        private final CopyOnWriteArrayList<Long> cellsToProcess = new CopyOnWriteArrayList<>();

        private RefineTask(List<Long> cellsToProcess) {
            this.cellsToProcess.addAll(cellsToProcess);
        }

        @Override
        protected List<Long> compute() {
            if (cellsToProcess.size() > THRESHOLD) {
                System.out.println("Fork/Join");
                return ForkJoinTask.invokeAll(createSubTasks())
                        .stream()
                        .map(ForkJoinTask::join)
                        .reduce(new ArrayList<>(), new Concat());
            } else {
                System.out.println("Direct computation");
                
                List<Long> leafCells = new ArrayList<>();

                for (Long cell : cellsToProcess) {
                    Long result = refineCell(cell);
                    if (result != null) {
                        leafCells.add(result);
                    }
                }

                return leafCells;
            }
        }

        private Collection<RefineTask> createSubTasks() {
            List<RefineTask> dividedTasks = new ArrayList<>();

            for (List<Long> list : split(cellsToProcess)) {
                dividedTasks.add(new RefineTask(list));
            }

            return dividedTasks;
        }

        private Long refineCell(Long cellId) {
            if (checkCell(cellId)) {
                for (int i = 0; i < 8; i++) {
                    Long newCell = cellId * 10 + i;
                    cellsToProcess.add(newCell);
                    System.out.println("Adding child " + newCell + " to cell " + cellId);
                }
                return null;
            } else {
                System.out.println("Leaf node " + cellId);
                return cellId;
            }
        }

        private List<List<Long>> split(List<Long> list)
        {
            int[] index = {0, (list.size() + 1)/2, list.size()};

            List<List<Long>> lists = IntStream.rangeClosed(0, 1)
                    .mapToObj(i -> list.subList(index[i], index[i + 1]))
                    .collect(Collectors.toList());

            return lists;
        }


    }



    private class Concat implements BinaryOperator<List<Long>> {
        @Override
        public List<Long> apply(List<Long> listOne, List<Long> listTwo) {
            return Stream.concat(listOne.stream(), listTwo.stream())
                    .collect(Collectors.toList());
        }
    }

    public boolean checkCell(Long cellId) {
        return cellId < 100;
    }
}
Run Code Online (Sandbox Code Playgroud)

以及测试它的方法:

    int initialSize = 4;
    List<Long> cellIds = new ArrayList<>(initialSize);
    for (int i = 0; i < initialSize; i++) {
        cellIds.add(Long.valueOf(i + 1));
    }

    ForkJoinAttempt test = new ForkJoinAttempt();
    test.refineGrid(cellIds);
    List<Long> leafCellIds = test.getLeafCellIds();
    System.out.println("Leaf nodes: " + leafCellIds.size());
    for (Long node : leafCellIds) {
        System.out.println(node);
    }
Run Code Online (Sandbox Code Playgroud)

输出确认它向每个根单元添加了 8 个子单元。但它并没有更进一步。

我知道为什么,但我不知道如何解决它:这是因为即使细化单元方法将新单元格添加到要处理的单元格列表中。createSubTask 方法不会再次被调用,因此它无法知道我添加了新单元格。

编辑2:

为了以不同的方式陈述问题,我正在寻找一种机制,其中一些Queue单元格 ID 由某些单元格处理RecursiveTask,而其他单元格 ID 则并行添加到单元格 IDQueue中。

Hol*_*ger 3

首先让\xe2\x80\x99s从基于Stream的解决方案开始

\n\n
public class Mesh {\n    public static long[] refineGrid(long[] cellsToProcess) {\n        return Arrays.stream(cellsToProcess).parallel().flatMap(Mesh::expand).toArray();\n    }\n    static LongStream expand(long d) {\n        return checkCell(d)? LongStream.of(d): generate(d).flatMap(Mesh::expand);\n    }\n    private static boolean checkCell(long cellId) {\n        return cellId > 100;\n    }\n    private static LongStream generate(long cellId) {\n        return LongStream.range(0, 8).map(j -> cellId * 10 + j);\n    }\n}\n
Run Code Online (Sandbox Code Playgroud)\n\n

虽然当前的flatMap实施存在已知问题,当网格过于不平衡时可能会出现这些问题,但实际任务的性能可能是合理的,因此在开始实现更复杂的东西之前,这个简单的解决方案总是值得一试。

\n\n

如果您确实需要自定义实现,例如,如果工作负载不平衡并且 Stream 实现不能很好地适应\xe2\x80\x99,您可以这样做:

\n\n
public class MeshTask extends RecursiveTask<long[]> {\n    public static long[] refineGrid(long[] cellsToProcess) {\n        return new MeshTask(cellsToProcess, 0, cellsToProcess.length).compute();\n    }\n    private final long[] source;\n    private final int from, to;\n\n    private MeshTask(long[] src, int from, int to) {\n        source = src;\n        this.from = from;\n        this.to = to;\n    }\n    @Override\n    protected long[] compute() {\n        return compute(source, from, to);\n    }\n    private static long[] compute(long[] source, int from, int to) {\n        long[] result = new long[to - from];\n        ArrayDeque<MeshTask> next = new ArrayDeque<>();\n        while(getSurplusQueuedTaskCount()<3) {\n            int mid = (from+to)>>>1;\n            if(mid == from) break;\n            MeshTask task = new MeshTask(source, mid, to);\n            next.push(task);\n            task.fork();\n            to = mid;\n        }\n        int pos = 0;\n        for(; from < to; ) {\n            long value = source[from++];\n            if(checkCell(value)) result[pos++]=value;\n            else {\n                long[] array = generate(value);\n                array = compute(array, 0, array.length);\n                result = Arrays.copyOf(result, result.length+array.length-1);\n                System.arraycopy(array, 0, result, pos, array.length);\n                pos += array.length;\n            }\n            while(from == to && !next.isEmpty()) {\n                MeshTask task = next.pop();\n                if(task.tryUnfork()) {\n                    to = task.to;\n                }\n                else {\n                    long[] array = task.join();\n                    int newLen = pos+to-from+array.length;\n                    if(newLen != result.length)\n                        result = Arrays.copyOf(result, newLen);\n                    System.arraycopy(array, 0, result, pos, array.length);\n                    pos += array.length;\n                }\n            }\n        }\n        return result;\n    }\n    static boolean checkCell(long cellId) {\n        return cellId > 1000;\n    }\n    static long[] generate(long cellId) {\n        long[] sub = new long[8];\n        for(int i = 0; i < sub.length; i++) sub[i] = cellId*10+i;\n        return sub;\n    }\n}\n
Run Code Online (Sandbox Code Playgroud)\n\n

此实现compute直接调用根任务的方法,将调用者线程合并到计算中。该compute方法用于getSurplusQueuedTaskCount()决定是否分裂。正如其文档所述,这个想法是始终有少量盈余,例如3。这确保了评估可以适应不平衡的工作负载,因为空闲线程可以从其他任务中窃取工作。

\n\n

拆分不是通过创建两个子任务并等待两个子任务来完成的。相反,仅拆分一个任务,代表待处理工作的后半部分,并且当前任务\xe2\x80\x99s 工作负载会进行调整以反映前半部分。

\n\n

然后,剩余的工作负载在本地处理。之后,最后推送的子任务被弹出并尝试取消分叉。如果取消分叉成功,则当前工作负载\xe2\x80\x99s范围也将适应覆盖后续任务\xe2\x80\x99s范围,并且本地迭代继续。

\n\n

这样,任何未被其他线程窃取的剩余任务都会以最简单、最轻量的方式处理,就好像它从未被分叉一样。

\n\n

如果任务已被另一个线程接收,我们现在必须等待其完成并合并结果数组。

\n\n

请注意,当通过 等待子任务时join(),底层实现还将检查是否可以取消分叉和本地评估,以使所有工作线程保持忙碌。但是,调整循环变量并直接将结果累加到目标数组中仍然比compute仍然需要合并结果数组的嵌套调用要好。

\n\n

如果单元不是叶子,则结果节点将通过相同的逻辑递归处理。这再次允许自适应本地和并发评估,因此执行将适应不平衡的工作负载,例如,如果特定单元具有更大的子树或特定单元任务的评估比其他单元任务长得多。

\n\n

必须强调的是,在所有情况下,都需要大量的处理工作量才能从并行处理中获益。如果像示例中那样,主要仅进行数据复制,那么好处可能会小得多,甚至不存在,或者在最坏的情况下,并行处理的性能可能比顺序处理更差。

\n

  • 对于数值变量,您可以使用,例如`LongStream.range(0, 8).flatMap(i -&gt; LongStream.range(0, 8).flatMap(j -&gt; LongStream.range(0, 8).map(k - &gt; cellId*1000+i*100+j*10+k)))` (2认同)