Tags: java concurrency multithreading fork-join java-8
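What I want
I want to study the optimization of fork/join algorithms. By optimization I mean calculating the optimal number of threads or, if you prefer, calculating SEQUENTIAL_THRESHOLD (see the code below).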
// PSEUDOCODE
Result solve(Problem problem) {
    if (problem.size < SEQUENTIAL_THRESHOLD)
        return solveSequentially(problem);
    else {
        Result left, right;
        INVOKE-IN-PARALLEL {
            left = solve(extractLeftHalf(problem));
            right = solve(extractRightHalf(problem));
        }
        return combine(left, right);
    }
}
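For readers who want something that compiles, the pseudocode maps onto `RecursiveTask` roughly as follows. This is only a sketch: the class name `ProductTask`, the choice of a `double[]` product as the problem, and the threshold value 2667 (borrowed from the worked example in this question) are assumptions made for illustration.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class ProductTask extends RecursiveTask<Double> {
    // Assumed value taken from the question's example; tune it by measurement.
    static final int SEQUENTIAL_THRESHOLD = 2667;

    private final double[] data;
    private final int start; // inclusive
    private final int end;   // exclusive

    ProductTask(double[] data, int start, int end) {
        this.data = data;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Double compute() {
        if (end - start < SEQUENTIAL_THRESHOLD) {
            // solveSequentially: multiply the elements of this slice directly
            double product = 1.0;
            for (int i = start; i < end; i++) {
                product *= data[i];
            }
            return product;
        }
        int mid = (start + end) >>> 1;
        ProductTask left = new ProductTask(data, start, mid);
        ProductTask right = new ProductTask(data, mid, end);
        left.fork();                          // run the left half asynchronously
        double rightResult = right.compute(); // work on the right half in this thread
        double leftResult = left.join();      // wait for the left half
        return leftResult * rightResult;      // combine
    }

    public static void main(String[] args) {
        double[] data = new double[100_000];
        Arrays.fill(data, 1.0);
        data[0] = 2.0;
        data[data.length - 1] = 3.0;
        double product = ForkJoinPool.commonPool().invoke(new ProductTask(data, 0, data.length));
        System.out.println(product);
    }
}
```

Forking one half and computing the other in the current thread keeps the calling worker busy instead of parking it while both halves run elsewhere.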
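How I imagine it
For example, I want to compute the product of a large array. I would then just evaluate all the components and get the optimal number of threads:
SEQUENTIAL_THRESHOLD = PC * IS / MC (just an example)
PC - number of processor cores;
IS - a constant giving the optimal array size for one processor core and the simplest possible operation on the data (for example, a read);
MC - cost of the multiply operation;
Suppose MC = 15, PC = 4 and IS = 10000; then SEQUENTIAL_THRESHOLD = 2667 (4 * 10000 / 15, rounded). If a subtask's array is larger than 2667 elements, I fork it.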
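The arithmetic above can be written down directly. The sketch below assumes the example formula as stated; the class and method names (`ThresholdEstimate`, `estimateThreshold`) are hypothetical, and the values of IS and MC would have to be measured for a real workload.

```java
class ThresholdEstimate {
    // SEQUENTIAL_THRESHOLD = PC * IS / MC, rounded to the nearest integer.
    // processorCores = PC, idealSize = IS, multiplyCost = MC (names are illustrative).
    static int estimateThreshold(int processorCores, int idealSize, int multiplyCost) {
        return (int) Math.round((double) processorCores * idealSize / multiplyCost);
    }

    public static void main(String[] args) {
        // Values from the example: PC = 4, IS = 10000, MC = 15
        System.out.println(estimateThreshold(4, 10_000, 15)); // prints 2667

        // Same formula with the core count reported by the JVM
        int cores = Runtime.getRuntime().availableProcessors();
        System.out.println(estimateThreshold(cores, 10_000, 15));
    }
}
```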
Broad question
Narrow question:
Is there already any research on calculating SEQUENTIAL_THRESHOLD for arrays, collections, or sorting? How do they go about it?
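Update, March 7, 2014:
Unless you are intimately tied to the execution environment, there is absolutely no way to calculate a proper threshold. I maintain a fork/join project on sourceforge.net, and this is the code I use in most of the built-in functions: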
private int calcThreshold(int nbr_elements, int passed_threshold) {
    int threads = getNbrThreads();   // total threads in session
    int count = nbr_elements + 1;    // total elements in array

    // When only one thread, it doesn't pay to decompose the work,
    // force the threshold over array length
    if (threads == 1) return count;

    /*
     * Whatever it takes
     */
    int threshold = passed_threshold;

    // When caller suggests a value
    if (threshold > 0) {
        // just go with the caller's suggestion or do something with the suggestion
    } else {
        // do something useful such as using about 8 times as many tasks as threads or
        // the default of 32k
        int temp = count / (threads << 3);
        threshold = (temp < 32768) ? 32768 : temp;
    } // endif

    // whatever
    return threshold;
}
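To see what this routine returns for concrete inputs, here is a standalone sketch of the same logic. It assumes the thread count is passed in as a parameter (standing in for `getNbrThreads()`, which is not shown) and that a caller-supplied threshold is used unchanged; the class name `ThresholdDemo` is made up for the example.

```java
class ThresholdDemo {
    // Same decision logic as calcThreshold above, with the thread count
    // supplied by the caller (an assumption made so the sketch is self-contained).
    static int calcThreshold(int nbrElements, int passedThreshold, int threads) {
        int count = nbrElements + 1;
        if (threads == 1) return count;                  // never split with a single thread
        if (passedThreshold > 0) return passedThreshold; // take the caller's suggestion as is
        int temp = count / (threads << 3);               // aim for about 8 tasks per thread
        return Math.max(temp, 32768);                    // but never go below 32k elements
    }

    public static void main(String[] args) {
        System.out.println(calcThreshold(1_000_000, 0, 1));     // 1000001: no decomposition
        System.out.println(calcThreshold(1_000_000, 0, 4));     // 32768: 31250 is below the floor
        System.out.println(calcThreshold(10_000_000, 0, 8));    // 156250: 10000001 / 64
        System.out.println(calcThreshold(10_000_000, 5000, 8)); // 5000: caller's value
    }
}
```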
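Edit, March 9:
How could you possibly have a general-purpose tool that knows not only the processor speed, available memory, number of processors and so on (the physical environment), but also the intent of the software? The answer is that you can't. That is why you need to develop a routine for each environment. The method above is the one I use for basic arrays (vectors). I use another one for most matrix processing: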
// When very small, just spread every row
if (count < 6) return 1;
// When small, spread a little
if (count < 30) return ((count / (threads << 2) == 0)? threads : (count / (threads << 2)));
// this works well for now
return ((count / (threads << 3) == 0)? threads : (count / (threads << 3)));
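Wrapped in a method, the fragment can be tried with a few inputs. In this sketch the method name `calcMatrixThreshold` and both parameters are assumptions: `count` is read as the number of rows (the comments speak of spreading rows) and `threads` as the worker thread count.

```java
class MatrixThresholdDemo {
    // The matrix fragment above, given a hypothetical signature.
    static int calcMatrixThreshold(int count, int threads) {
        // When very small, just spread every row
        if (count < 6) return 1;
        // When small, spread a little (about 4 tasks per thread)
        if (count < 30) {
            int perTask = count / (threads << 2);
            return (perTask == 0) ? threads : perTask;
        }
        // Otherwise about 8 tasks per thread
        int perTask = count / (threads << 3);
        return (perTask == 0) ? threads : perTask;
    }

    public static void main(String[] args) {
        System.out.println(calcMatrixThreshold(5, 4));    // 1
        System.out.println(calcMatrixThreshold(10, 4));   // 4: 10 / 16 is 0, so fall back to threads
        System.out.println(calcMatrixThreshold(20, 4));   // 1: 20 / 16
        System.out.println(calcMatrixThreshold(1000, 4)); // 31: 1000 / 32
    }
}
```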
As for Java 8 streams: they use the underlying F/J framework, and you cannot specify a threshold.
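A short sketch of what that looks like in practice: the parallel stream call takes no threshold argument, and how the work is split is decided inside the library. What can be inspected is the parallelism of the common pool that parallel streams run on by default. The class and method names here (`ParallelStreamDemo`, `parallelSum`) are illustrative only.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.stream.IntStream;

class ParallelStreamDemo {
    // Sums the array with a parallel stream; note there is no threshold parameter anywhere.
    static long parallelSum(int[] values) {
        return IntStream.of(values).parallel().asLongStream().sum();
    }

    public static void main(String[] args) {
        int[] values = IntStream.rangeClosed(1, 100_000).toArray();
        System.out.println(parallelSum(values)); // prints 5000050000

        // Parallelism of the common pool used by parallel streams by default
        System.out.println(ForkJoinPool.getCommonPoolParallelism());
    }
}
```

The common pool's parallelism can be set at JVM start with the system property `java.util.concurrent.ForkJoinPool.common.parallelism`, but that changes the number of workers, not the granularity at which the stream splits its source.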
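This is a very interesting problem and well worth investigating. I wrote this simple code to test for the best value of the sequential threshold. I could not reach any concrete conclusion, though, most likely because I ran it on an old laptop with only 2 processors. The only consistent observation over many runs was that the time taken drops rapidly until the sequential threshold reaches about 100. Try running this code and let me know what you find. At the bottom I have also attached a Python script for plotting the results, so that we can see the trend visually.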
import java.io.FileWriter;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

public class Testing {
    static int SEQ_THRESHOLD;

    public static void main(String[] args) throws Exception {
        int size = 100000;
        int[] v1 = new int[size];
        int[] v2 = new int[size];
        int[] v3 = new int[size];
        for (int i = 0; i < size; i++) {
            v1[i] = i;     // Arbitrary initialization
            v2[i] = 2 * i; // Arbitrary initialization
        }
        // One pool for the whole run, so pool start-up is not part of the
        // measured time and worker threads are not leaked on every sample
        ForkJoinPool fjp = new ForkJoinPool();
        try (FileWriter fileWriter = new FileWriter("OutTime.dat")) {
            // Increment SEQ_THRESHOLD and save the time taken by the code to run in a file
            for (SEQ_THRESHOLD = 10; SEQ_THRESHOLD < size; SEQ_THRESHOLD += 50) {
                double avgTime = 0.0;
                int samples = 5;
                for (int i = 0; i < samples; i++) {
                    long startTime = System.nanoTime();
                    fjp.invoke(new VectorAddition(0, size, v1, v2, v3));
                    long endTime = System.nanoTime();
                    double secsTaken = (endTime - startTime) / 1.0e9;
                    avgTime += secsTaken;
                }
                // One line per threshold: "<threshold> <average seconds>"
                fileWriter.write(SEQ_THRESHOLD + " " + (avgTime / samples) + "\n");
            }
        } finally {
            fjp.shutdown();
        }
    }
}

class VectorAddition extends RecursiveAction {
    int[] v1, v2, v3;
    int start, end; // half-open range [start, end)
    // Threshold in effect when this task was created
    int SEQ_THRESHOLD = Testing.SEQ_THRESHOLD;

    VectorAddition(int start, int end, int[] v1, int[] v2, int[] v3) {
        this.start = start;
        this.end = end;
        this.v1 = v1;
        this.v2 = v2;
        this.v3 = v3;
    }

    @Override
    protected void compute() {
        if (end - start < SEQ_THRESHOLD) {
            // Simple vector addition
            for (int i = start; i < end; i++) {
                v3[i] = v1[i] + v2[i];
            }
        } else {
            // Split the range in half and process both halves in parallel
            int mid = (start + end) / 2;
            invokeAll(new VectorAddition(start, mid, v1, v2, v3),
                      new VectorAddition(mid, end, v1, v2, v3));
        }
    }
}
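One possible reason the measurements are hard to interpret is JIT compilation: the earliest samples run partly interpreted code, so they are slower regardless of the threshold. A common mitigation, sketched below, is to discard a few warm-up runs and report the median instead of the mean. The helper `medianNanos` and the class `TimingSketch` are hypothetical and not part of the program above; a dedicated harness such as JMH would be more rigorous.

```java
import java.util.Arrays;

class TimingSketch {
    // Runs `work` `warmups` times untimed, then `samples` times timed,
    // and returns the median elapsed time in nanoseconds.
    static long medianNanos(Runnable work, int warmups, int samples) {
        for (int i = 0; i < warmups; i++) {
            work.run(); // give the JIT a chance to compile the hot path
        }
        long[] times = new long[samples];
        for (int i = 0; i < samples; i++) {
            long start = System.nanoTime();
            work.run();
            times[i] = System.nanoTime() - start;
        }
        Arrays.sort(times);
        return times[samples / 2];
    }

    public static void main(String[] args) {
        int[] data = new int[100_000];
        long nanos = medianNanos(() -> Arrays.fill(data, 1), 10, 5);
        System.out.println("median ns: " + nanos);
    }
}
```

In the benchmark above, the `Runnable` would wrap the `fjp.invoke(...)` call for one threshold value.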
Here is the Python script for plotting the results:
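import numpy as np
import matplotlib.pyplot as plt

# Each line of OutTime.dat holds "<threshold> <average seconds>"
threshold, time_taken = np.loadtxt("./OutTime.dat", unpack=True)

plt.plot(threshold, time_taken)
plt.xlabel("Sequential threshold")
plt.ylabel("Average time (s)")
plt.show()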