Multithreaded Segmented Sieve of Eratosthenes in Java

Asked by MC *_*tch (score: 7). Tags: java, arrays, primes, multithreading, sieve-of-eratosthenes

I am trying to create a fast prime generator in Java. It is (more or less) accepted that the fastest way for this is the segmented sieve of Eratosthenes: https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes. Lots of optimizations can be further implemented to make it faster. As of now, my implementation generates 50847534 primes below 10^9 in about 1.6 seconds, but I am looking to make it faster and at least break the 1 second barrier. To increase the chance of getting good replies, I will include a walkthrough of the algorithm as well as the code.

TL;DR: I am looking to add multithreading to the code.

For the purposes of this question, I want to distinguish between the 'segmented' and the 'traditional' sieves of Eratosthenes. The traditional sieve requires O(n) space and is therefore very limited in the range of its input (the limit). The segmented sieve, however, only requires O(n^0.5) space and can operate on much larger limits. (A main speed-up comes from cache-friendly segmentation, which takes the L1 and L2 cache sizes of the specific computer into account.)

The main difference that concerns my question is that the traditional sieve is sequential: each step can only start once the previous steps are complete. The segmented sieve is not. Each segment is independent and is 'processed' individually against the sieving primes (the primes not larger than n^0.5). This means that, in theory, once I have the sieving primes, I can divide the work among multiple computers, each processing a different segment without depending on the others' work. Assuming (wrongly) that each segment takes the same amount of time t to complete, and that there are k segments, one computer would need a total time of T = k * t, whereas k computers, each working on a different segment, would need only T = t for the entire process. (In practice this is not exact, but it keeps the example simple.)
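To make the independence argument concrete, here is a minimal sketch of a plain segmented sieve. It uses no wheel and no odd-only storage. The class and method names are illustrative and do not come from the code in this question. Each `countSegment` call only reads the shared base primes and writes to its own local array, so the segments could be processed in any order, or on different threads.

```java
import java.util.ArrayList;
import java.util.List;

public class SegmentedSieveDemo {

    // Simple sieve for the base ("sieving") primes up to 'limit' = floor(sqrt(n)).
    static int[] basePrimes(int limit) {
        boolean[] composite = new boolean[limit + 1];
        List<Integer> primes = new ArrayList<>();
        for (int i = 2; i <= limit; i++) {
            if (composite[i]) continue;
            primes.add(i);
            for (long j = (long) i * i; j <= limit; j += i) composite[(int) j] = true;
        }
        return primes.stream().mapToInt(Integer::intValue).toArray();
    }

    // Counts primes in [low, high). Reads only the shared base primes and
    // writes only its own local array, so it never depends on another segment.
    static long countSegment(long low, long high, int[] base) {
        boolean[] composite = new boolean[(int) (high - low)];
        for (int p : base) {
            // First multiple of p that is >= low, but never below p*p
            // (smaller multiples have a smaller prime factor, and p itself stays unmarked).
            long start = Math.max((long) p * p, (low + p - 1) / p * p);
            for (long j = start; j < high; j += p) composite[(int) (j - low)] = true;
        }
        long count = 0;
        for (long v = Math.max(low, 2); v < high; v++)
            if (!composite[(int) (v - low)]) count++;
        return count;
    }

    // Counts primes below n; the iterations of this loop could run in any order, or in parallel.
    static long countPrimesBelow(long n, int segSize) {
        int[] base = basePrimes((int) Math.sqrt((double) n));
        long total = 0;
        for (long low = 0; low < n; low += segSize)
            total += countSegment(low, Math.min(n, low + segSize), base);
        return total;
    }

    public static void main(String[] args) {
        System.out.println(countPrimesBelow(1_000_000, 32_768)); // 78498
    }
}
```

Running it with different segment sizes gives the same count. This is a quick way to check that no segment depends on another.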

This led me to read about multithreading: dividing the work among a few threads, each processing a smaller share, to make better use of the CPU. As I understand it, the traditional sieve cannot be multithreaded precisely because it is sequential. Each thread would depend on the previous one, which defeats the whole idea. A segmented sieve, however, may indeed (I think) be multithreaded.

Instead of jumping straight into my question, I think it is important to introduce my code first, so I am hereby including my current fastest implementation of the segmented sieve. I have worked quite hard on it. It took quite some time, slowly tweaking and adding optimizations to it. The code is not simple. It is rather complex, I would say. I therefore assume the reader is familiar with the concepts I am introducing, such as wheel factorization, prime numbers, segmentation and more. I have included notes to make it easier to follow.

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;

public class primeGen {

    public static long x = (long)Math.pow(10, 9); //limit
    public static int sqrtx;
    public static boolean [] sievingPrimes; //the sieving primes, <= sqrtx

    public static int [] wheels = new int [] {2,3,5,7,11,13,17,19}; // base wheel primes
    public static int [] gaps; //the gaps, according to the wheel. will enable skipping multiples of the wheel primes
    public static int nextp; // the first prime > wheel primes
    public static int l; // the amount of gaps in the wheel

    public static void main(String[] args)
    {
        long startTime = System.currentTimeMillis();

        preCalc();  // creating the sieving primes and calculating the list of gaps

        int segSize = Math.max(sqrtx, 32768*8); //size of each segment
        long u = nextp; // 'u' is the running index of the program. will continue from one segment to the next
        int wh = 0; // this will be the gap index, indicating by how much we increment 'u' each time, skipping the multiples of the wheel primes

        long pi = pisqrtx(); // the primes count. initialize with the number of primes <= sqrtx

        for (long low = 0 ; low < x ; low += segSize) //the heart of the code. enumerating the primes through segmentation. enumeration will begin at p > sqrtx
        {
            long high = Math.min(x, low + segSize);
            boolean [] segment = new boolean [(int) (high - low + 1)];

            int g = -1;
            for (int i = nextp ; i <= sqrtx ; i += gaps[g])
            { 
                if (sievingPrimes[(i + 1) / 2])
                {
                    long firstMultiple = (long) (low / i * i);
                    if (firstMultiple < low) 
                        firstMultiple += i; 
                    if (firstMultiple % 2 == 0) //start with the first odd multiple of the current prime in the segment
                        firstMultiple += i;

                    for (long j = firstMultiple ; j < high ; j += i * 2) 
                        segment[(int) (j - low)] = true; 
                }
                g++;
                //if (g == l) // sqrtx is far below the wheel circumference, so this loop never uses the full list of gaps and the check is redundant.
                              // it is needed for much larger limits (sqrtx reaching about nextp + primorial) or smaller lists of gaps
                    //g = 0;
            }

            while (u <= high)
            {
                if (!segment[(int) (u - low)])
                    pi++;
                u += gaps[wh];
                wh++;
                if (wh == l)
                    wh = 0;
            }
        }

        System.out.println(pi);

        long endTime = System.currentTimeMillis();
        System.out.println("Solution took "+(endTime - startTime) + " ms");
    }

    public static boolean [] simpleSieve (int n) // n = upper limit; not named 'l', which would shadow the static gap count used below
    {
        long sqrtn = (long)Math.sqrt(n);
        boolean [] primes = new boolean [n/2+2];
        Arrays.fill(primes, true);
        int g = -1;
        for (int i = nextp ; i <= sqrtn ; i += gaps[g])
        {
            if (primes[(i + 1) / 2])
                for (int j = i * i ; j <= n ; j += i * 2)
                    primes[(j + 1) / 2]=false;
            g++;
            if (g == l) // l is the static gap count (gaps.length): wrap around the wheel
                g=0;
        }
        return primes;
    }

    public static long pisqrtx ()
    {
        int pi = wheels.length;
        if (x < wheels[wheels.length-1])
        {
            if (x < 2)
                return 0;
            int k = 0;
            while (wheels[k] <= x)
                k++;
            return k;
        }
        int g = -1;
        for (int i = nextp ; i <= sqrtx ; i += gaps[g])
        {
            if(sievingPrimes[( i + 1 ) / 2])
                pi++;
            g++;
            if (g == l)
                g=0;
        }

        return pi;
    }

    public static void preCalc ()
    {
        sqrtx = (int) Math.sqrt(x);

        int prod = 1;
        for (long p : wheels)
            prod *= p; // primorial
        nextp = BigInteger.valueOf(wheels[wheels.length-1]).nextProbablePrime().intValue(); //the first prime that comes after the wheel
        int lim = prod + nextp; // circumference of the wheel

        boolean [] marks = new boolean [lim + 1];
        Arrays.fill(marks, true);

        for (int j = 2 * 2 ;j <= lim ; j += 2)
            marks[j] = false;
        for (int i = 1 ; i < wheels.length ; i++)
        {
            int p = wheels[i];
            for (int j = p * p ; j <= lim ; j += 2 * p)
                marks[j]=false;   // removing all integers that are NOT coprime with the base wheel primes
        }
        ArrayList <Integer> gs = new ArrayList <Integer>(); //list of the gaps between the integers that are coprime with the base wheel primes
        int d = nextp;
        for (int p = d + 2 ; p < marks.length ; p += 2)
        {
            if (marks[p]) //d is coprime with the wheel primes. if p is too, a gap is identified and noted.
            {
                gs.add(p - d);
                d = p;
            }
        }
        gaps = new int [gs.size()];
        for (int i = 0 ; i < gs.size() ; i++)
            gaps[i] = gs.get(i); // Arrays are faster than lists, so moving the list of gaps to an array
        l = gaps.length;

        sievingPrimes = simpleSieve(sqrtx); //initializing the sieving primes
    }

}

Currently, it produces 50847534 primes below 10^9 in about 1.6 seconds. This is very impressive, at least by my standards, but I am looking to make it faster, possibly break the 1 second barrier. Even then, I believe it can be made much faster still.

The whole program is based on wheel factorization: https://en.wikipedia.org/wiki/Wheel_factorization. I have noticed I am getting the fastest results using a wheel of all primes up to 19.

public static int [] wheels = new int [] {2,3,5,7,11,13,17,19}; // base wheel primes

This means that multiples of those primes are skipped, which leaves a much smaller search range. The preCalc method computes the gaps between consecutive integers that are coprime with all the wheel primes. Stepping through the search range by those gaps skips every multiple of the base primes.

    public static void preCalc ()
    {
        sqrtx = (int) Math.sqrt(x);

        int prod = 1;
        for (long p : wheels)
            prod *= p; // primorial
        nextp = BigInteger.valueOf(wheels[wheels.length-1]).nextProbablePrime().intValue(); //the first prime that comes after the wheel
        int lim = prod + nextp; // circumference of the wheel

        boolean [] marks = new boolean [lim + 1];
        Arrays.fill(marks, true);

        for (int j = 2 * 2 ;j <= lim ; j += 2)
            marks[j] = false;
        for (int i = 1 ; i < wheels.length ; i++)
        {
            int p = wheels[i];
            for (int j = p * p ; j <= lim ; j += 2 * p)
                marks[j]=false;   // removing all integers that are NOT coprime with the base wheel primes
        }
        ArrayList <Integer> gs = new ArrayList <Integer>(); //list of the gaps between the integers that are coprime with the base wheel primes
        int d = nextp;
        for (int p = d + 2 ; p < marks.length ; p += 2)
        {
            if (marks[p]) //d is coprime with the wheel primes. if p is too, a gap is identified and noted.
            {
                gs.add(p - d);
                d = p;
            }
        }
        gaps = new int [gs.size()];
        for (int i = 0 ; i < gs.size() ; i++)
            gaps[i] = gs.get(i); // Arrays are faster than lists, so moving the list of gaps to an array
        l = gaps.length;

        sievingPrimes = simpleSieve(sqrtx); //initializing the sieving primes
    } 
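For a scaled-down picture of what preCalc builds, here is a sketch that computes the gap list for the wheel {2, 3, 5} (circumference 30), starting at 7. For readability it uses trial division by the wheel primes instead of the marks array. The names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class WheelGapsDemo {

    // True if v shares no factor with any wheel prime.
    static boolean coprimeToWheel(int v, int[] wheel) {
        for (int p : wheel) if (v % p == 0) return false;
        return true;
    }

    // Gaps between consecutive wheel-coprime integers, starting at 'first'
    // (which must itself be coprime) and covering exactly one turn: first .. first + product.
    static int[] wheelGaps(int[] wheel, int first) {
        int prod = 1;
        for (int p : wheel) prod *= p;            // the primorial, i.e. the wheel's circumference
        List<Integer> gaps = new ArrayList<>();
        int prev = first;
        for (int v = first + 1; v <= first + prod; v++) {
            if (coprimeToWheel(v, wheel)) {
                gaps.add(v - prev);
                prev = v;
            }
        }
        return gaps.stream().mapToInt(Integer::intValue).toArray();
    }

    public static void main(String[] args) {
        int[] gaps = wheelGaps(new int[] {2, 3, 5}, 7);
        System.out.println(Arrays.toString(gaps));     // [4, 2, 4, 2, 4, 6, 2, 6]
        System.out.println(Arrays.stream(gaps).sum()); // 30, one full turn
    }
}
```

One full turn always sums to the primorial. The number of gaps equals the count of residues coprime to it, which is 8 for the 30-wheel.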

At the end of preCalc, simpleSieve is called to compute the sieving primes mentioned before, i.e. the primes <= sqrtx. This is a simple (non-segmented) sieve of Eratosthenes, but it still relies on the previously computed wheel gaps.

    public static boolean [] simpleSieve (int n) // n = upper limit; not named 'l', which would shadow the static gap count used below
    {
        long sqrtn = (long)Math.sqrt(n);
        boolean [] primes = new boolean [n/2+2];
        Arrays.fill(primes, true);
        int g = -1;
        for (int i = nextp ; i <= sqrtn ; i += gaps[g])
        {
            if (primes[(i + 1) / 2])
                for (int j = i * i ; j <= n ; j += i * 2)
                    primes[(j + 1) / 2]=false;
            g++;
            if (g == l) // l is the static gap count (gaps.length): wrap around the wheel
                g=0;
        }
        return primes;
    } 
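The (i + 1) / 2 indexing in simpleSieve is an odd-only layout, in which even numbers get no slot at all. Below is a minimal sketch of the same idea without the wheel. It uses the more common index i / 2. For odd i, (i + 1) / 2 is always exactly one more than i / 2, so the two layouts differ only by a shift of one slot. The names are illustrative.

```java
public class OddOnlySieveDemo {

    // composite[k] describes the odd number 2k + 1; even numbers are not stored at all,
    // so the array needs about n/2 slots instead of n.
    static boolean[] oddComposites(int n) {
        boolean[] composite = new boolean[n / 2 + 1];
        composite[0] = true;                                   // 2*0 + 1 = 1 is not prime
        for (int i = 3; (long) i * i <= n; i += 2) {
            if (composite[i >>> 1]) continue;                  // index of odd i is i / 2
            // Start at i*i and step by 2*i so only odd multiples are visited.
            for (long j = (long) i * i; j <= n; j += 2L * i) composite[(int) (j >>> 1)] = true;
        }
        return composite;
    }

    static int countPrimesUpTo(int n) {
        if (n < 2) return 0;
        boolean[] composite = oddComposites(n);
        int count = 1;                                         // the even prime 2
        for (int k = 1; 2L * k + 1 <= n; k++)
            if (!composite[k]) count++;
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countPrimesUpTo(100));       // 25
        System.out.println(countPrimesUpTo(1_000_000)); // 78498
    }
}
```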

Finally, we reach the heart of the algorithm. We start by enumerating all primes <= sqrtx, with the following call:

long pi = pisqrtx();

which uses the following method:

    public static long pisqrtx ()
    {
        int pi = wheels.length;
        if (x < wheels[wheels.length-1])
        {
            if (x < 2)
                return 0;
            int k = 0;
            while (wheels[k] <= x)
                k++;
            return k;
        }
        int g = -1;
        for (int i = nextp ; i <= sqrtx ; i += gaps[g])
        {
            if(sievingPrimes[( i + 1 ) / 2])
                pi++;
            g++;
            if (g == l)
                g=0;
        }

        return pi;
    } 

After initializing pi, which holds the running prime count, we perform the segmentation described earlier, counting from the first prime > sqrtx:

        int segSize = Math.max(sqrtx, 32768*8); //size of each segment
        long u = nextp; // 'u' is the running index of the program. will continue from one segment to the next
        int wh = 0; // this will be the gap index, indicating by how much we increment 'u' each time, skipping the multiples of the wheel primes

        long pi = pisqrtx(); // the primes count. initialize with the number of primes <= sqrtx

        for (long low = 0 ; low < x ; low += segSize) //the heart of the code. enumerating the primes through segmentation. enumeration will begin at p > sqrtx
        {
            long high = Math.min(x, low + segSize);
            boolean [] segment = new boolean [(int) (high - low + 1)];

            int g = -1;
            for (int i = nextp ; i <= sqrtx ; i += gaps[g])
            { 
                if (sievingPrimes[(i + 1) / 2])
                {
                    long firstMultiple = (long) (low / i * i);
                    if (firstMultiple < low) 
                        firstMultiple += i; 
                    if (firstMultiple % 2 == 0) //start with the first odd multiple of the current prime in the segment
                        firstMultiple += i;

                    for (long j = firstMultiple ; j < high ; j += i * 2) 
                        segment[(int) (j - low)] = true; 
                }
                g++;
                //if (g == l) // sqrtx is far below the wheel circumference, so this loop never uses the full list of gaps and the check is redundant.
                              // it is needed for much larger limits (sqrtx reaching about nextp + primorial) or smaller lists of gaps
                    //g = 0;
            }

            while (u <= high)
            {
                if (!segment[(int) (u - low)])
                    pi++;
                u += gaps[wh];
                wh++;
                if (wh == l)
                    wh = 0;
            }
        } 
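As a small check of the firstMultiple arithmetic in the loop above, here is the same computation pulled out into a standalone helper (the name is hypothetical). One consequence is worth noting. For the first segment (low = 0), the result is p itself, so every sieving prime marks its own position. That is consistent with the design: pisqrtx() has already counted the primes <= sqrtx, so the counting loop only needs to find the primes above sqrtx.

```java
public class FirstOddMultipleDemo {

    // Smallest odd multiple of the odd number p that is >= low,
    // computed with the same three steps as the question's segment loop.
    static long firstOddMultiple(long low, int p) {
        long m = low / p * p;      // largest multiple of p that is <= low
        if (m < low) m += p;       // now the smallest multiple >= low
        if (m % 2 == 0) m += p;    // p is odd, so adding it once makes m odd
        return m;
    }

    public static void main(String[] args) {
        System.out.println(firstOddMultiple(100, 7));  // 105
        System.out.println(firstOddMultiple(0, 23));   // 23: in segment 0 each prime marks itself
        System.out.println(firstOddMultiple(1000, 3)); // 1005
    }
}
```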

I noted this in a code comment, but I will explain it here as well. With the 19-wheel, the gap list has 1,658,880 entries spanning 9,699,690 numbers. Because sqrtx is small compared with that circumference, the loop over the sieving primes never goes through the entire list of gaps within a segment, so the wrap-around check there is redundant. Across the whole program, however, the counting loop does use the entire array of gaps. The gap index wh that advances u therefore has to wrap around instead of running past the end:

            while (u <= high)
            {
                if (!segment[(int) (u - low)])
                    pi++;
                u += gaps[wh];
                wh++;
                if (wh == l)
                    wh = 0;
            } 
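The wrap-around can be isolated into a tiny cursor object. This is a hypothetical refactoring, not part of the question's code. Because the cursor lives outside any one segment, it carries both u and the gap index from segment to segment. The question's loop does the same thing by declaring u and wh outside the for loop.

```java
public class WheelCursorDemo {

    // Walks the wheel-coprime integers: 'value' plays the role of u and
    // 'index' the role of wh in the question's counting loop.
    static final class WheelCursor {
        private final int[] gaps;
        private int index = 0;
        private long value;

        WheelCursor(int[] gaps, long start) {
            this.gaps = gaps;
            this.value = start;
        }

        long current() { return value; }

        void advance() {
            value += gaps[index];
            if (++index == gaps.length) index = 0;  // same wrap as "if (wh == l) wh = 0"
        }
    }

    public static void main(String[] args) {
        // Gaps of the {2, 3, 5} wheel starting at 7; one turn of 8 gaps spans 30.
        WheelCursor cursor = new WheelCursor(new int[] {4, 2, 4, 2, 4, 6, 2, 6}, 7);
        StringBuilder out = new StringBuilder();
        for (int i = 0; i < 10; i++) {
            out.append(cursor.current()).append(' ');
            cursor.advance();
        }
        System.out.println(out.toString().trim()); // 7 11 13 17 19 23 29 31 37 41
    }
}
```

With 8 gaps, the ninth advance wraps back to gaps[0]. That is why 37 is followed by 41 (= 37 + 4).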

Using much higher limits (and therefore a larger sqrtx, and eventually larger segments) may make the wrap-around check necessary inside the segment as well. Changing the set of wheel primes can have the same effect. Switching to bit-sieving, though, could substantially raise the practical segment size.
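One way to read the bit-sieving suggestion is to pack only the odd numbers of a segment into a long[], one bit each. Below is a sketch that assumes segments start at even offsets. It uses plain odd stepping rather than the question's wheel gaps, and all names are illustrative. In HotSpot a boolean[] element normally takes one byte. This layout therefore stores about 16 times as many numbers in the same memory, which lets a much larger segment fit in a given cache level.

```java
import java.util.Arrays;

public class BitSegmentDemo {

    // Odd primes up to 'limit', found with a plain odd-only sieve.
    static int[] oddBasePrimes(int limit) {
        boolean[] composite = new boolean[limit + 1];
        int[] found = new int[limit + 1];
        int n = 0;
        for (int i = 3; i <= limit; i += 2) {
            if (composite[i]) continue;
            found[n++] = i;
            for (long j = (long) i * i; j <= limit; j += 2L * i) composite[(int) j] = true;
        }
        return Arrays.copyOf(found, n);
    }

    // Counts primes in [low, high) for even low. Bit k of the segment stands for
    // the odd number low + 2k + 1, so one long word covers 128 consecutive integers.
    static long countSegment(long low, long high, int[] oddBase) {
        int bits = (int) ((high - low) / 2);
        long[] composite = new long[(bits + 63) >>> 6];
        for (int p : oddBase) {
            long start = Math.max((long) p * p, (low + p - 1) / p * p);
            if (start % 2 == 0) start += p;                     // first odd multiple
            for (long j = start; j < high; j += 2L * p) {
                int k = (int) ((j - low) >>> 1);                // j odd, low even
                composite[k >>> 6] |= 1L << (k & 63);
            }
        }
        long count = bits;                                      // unset bits are primes
        for (long word : composite) count -= Long.bitCount(word);
        if (low == 0 && bits > 0) count -= 1;                   // bit 0 is the number 1
        if (low <= 2 && 2 < high) count += 1;                   // the even prime 2
        return count;
    }

    // segSize must be even so that every segment starts at an even low.
    static long countPrimesBelow(long n, int segSize) {
        int[] oddBase = oddBasePrimes((int) Math.sqrt((double) n));
        long total = 0;
        for (long low = 0; low < n; low += segSize)
            total += countSegment(low, Math.min(n, low + segSize), oddBase);
        return total;
    }

    public static void main(String[] args) {
        System.out.println(countPrimesBelow(1_000_000, 1 << 16)); // 78498
    }
}
```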

  • As an important side note, I am aware that efficient segmentation takes the L1 and L2 cache sizes into account. I get the fastest results with a segment size of 32,768 * 8 = 262,144 = 2^18. I am not sure what the cache sizes of my computer are, but I doubt they are that big, since most L1 cache sizes I see are 32 KB or less. (Each boolean[] element typically occupies one byte, so such a segment is about 256 KB.) Still, this size produces the fastest run time on my computer, which is why I chose it.
  • As I mentioned, I am still looking to improve this by a lot. Based on my introduction, I believe multithreading could give a speed-up factor of 4 using 4 threads (one per core). The idea is that each thread would still use the segmented sieve, but on a different portion of the range. I would divide n into 4 equal parts, one per thread, and each thread would run the segmentation above on the n/4 numbers it is responsible for. My question is: how do I do that? Reading about multithreading and examples unfortunately did not give me any insight into how to implement it efficiently in this case. In my attempt, contrary to the whole idea, the threads seemed to run sequentially rather than simultaneously. That is why I left that attempt out of the code to keep it readable. I would really appreciate a code sample for this specific code, but a good explanation and a reference might do the trick too.
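The threaded attempt itself isn't shown, so this is only a guess at the "threads ran sequentially" symptom. Two common causes are calling run() instead of start() (run() just executes the body on the calling thread), and calling join() right after each start() in the same loop, which waits for each thread before starting the next. The sketch below splits [0, n) into one contiguous chunk per thread, as described above, with a plain (non-wheel) segmented sieve inside each chunk. The names are illustrative.

```java
import java.util.Arrays;

public class ThreadedSieveDemo {

    static int[] basePrimes(int limit) {
        boolean[] composite = new boolean[limit + 1];
        int[] found = new int[limit + 1];
        int n = 0;
        for (int i = 2; i <= limit; i++) {
            if (composite[i]) continue;
            found[n++] = i;
            for (long j = (long) i * i; j <= limit; j += i) composite[(int) j] = true;
        }
        return Arrays.copyOf(found, n);
    }

    // Counts primes in [from, to) segment by segment, reusing one private buffer.
    static long countRange(long from, long to, int[] base, int segSize) {
        boolean[] composite = new boolean[segSize];
        long count = 0;
        for (long low = from; low < to; low += segSize) {
            long high = Math.min(to, low + segSize);
            Arrays.fill(composite, 0, (int) (high - low), false);
            for (int p : base) {
                long start = Math.max((long) p * p, (low + p - 1) / p * p);
                for (long j = start; j < high; j += p) composite[(int) (j - low)] = true;
            }
            for (long v = Math.max(low, 2); v < high; v++)
                if (!composite[(int) (v - low)]) count++;
        }
        return count;
    }

    static long countPrimesBelow(long n, int threads, int segSize) throws InterruptedException {
        int[] base = basePrimes((int) Math.sqrt((double) n)); // built once, then only read
        long[] partial = new long[threads];                    // thread t writes only partial[t]
        Thread[] workers = new Thread[threads];
        long chunk = (n + threads - 1) / threads;
        for (int t = 0; t < threads; t++) {
            final int id = t;
            final long from = Math.min(n, id * chunk);
            final long to = Math.min(n, from + chunk);
            workers[t] = new Thread(() -> partial[id] = countRange(from, to, base, segSize));
            workers[t].start();   // start(), not run(): run() executes on the calling thread
        }
        for (Thread w : workers) w.join(); // join only after all threads have started
        return Arrays.stream(partial).sum(); // join() also makes partial[] writes visible here
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(countPrimesBelow(1_000_000, 4, 1 << 15)); // 78498
    }
}
```

Adapting this to the question's program would mean giving each thread its own segment array and its own u/wh pair, positioned at the first wheel position inside its chunk. Working out that starting gap index is left open here. The static fields filled in by preCalc() are only read afterwards, so the threads can share them.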

Additionally, I would love to hear any other ideas for speeding this program up even further. I really want to make it as fast and efficient as possible. Thank you!

Answer by 900*_*000 (score: 1)

Something like this should get you started.

Solution outline:

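  • Define a data structure that holds a particular segment (a "task"). You can also put all the immutable shared data into it for extra tidiness. If you are careful enough, you can pass one common mutable array to all tasks, along with the segment limits, and have each task update only the part of the array within its limits. This is more error-prone, but it can simplify the step of joining the results (AFAICT; YMMV).
  • Define a data structure that stores the result of a task's computation (a "result"). Even if you just update a shared result structure, you may need to indicate which part of that structure has been updated so far.
  • Create a Runnable that accepts a task, runs the computation, and puts the result into a given result queue.
  • Create a blocking input queue for the tasks, and a queue for the results.
  • Create a ThreadPoolExecutor with a number of threads close to the number of cores on the machine.
  • Submit all tasks to the thread pool executor. They will be scheduled to run on the threads in the pool and will put their results into the output queue, not necessarily in order.
  • Wait for all tasks in the thread pool to complete.
  • Drain the output queue and join the partial results into the final result.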

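Additional speed-up may (or may not) be gained by joining the results in a separate task that reads the output queue, or even by updating a mutable shared output structure under synchronized. Which works better depends on how much work the joining step involves.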
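Here is one way the outline could look in Java. It is a sketch under several assumptions. Each Task is one plain segment (no wheel), the pool size is Runtime.availableProcessors(), and the segment counts are simply summed when the result queue is drained. Task, Result and the helper methods are names chosen for this example.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class SieveTaskQueueDemo {

    // "Task": one segment plus the immutable shared data it needs.
    static final class Task {
        final long low, high;
        final int[] basePrimes;                // shared and never modified
        Task(long low, long high, int[] basePrimes) {
            this.low = low; this.high = high; this.basePrimes = basePrimes;
        }
    }

    // "Result": which segment was processed and how many primes it contained.
    static final class Result {
        final long low, count;
        Result(long low, long count) { this.low = low; this.count = count; }
    }

    static int[] basePrimes(int limit) {
        boolean[] composite = new boolean[limit + 1];
        List<Integer> primes = new ArrayList<>();
        for (int i = 2; i <= limit; i++) {
            if (composite[i]) continue;
            primes.add(i);
            for (long j = (long) i * i; j <= limit; j += i) composite[(int) j] = true;
        }
        return primes.stream().mapToInt(Integer::intValue).toArray();
    }

    // Plain segmented sieve over [low, high); no wheel, for brevity.
    static long countSegment(Task t) {
        boolean[] composite = new boolean[(int) (t.high - t.low)];
        for (int p : t.basePrimes) {
            long start = Math.max((long) p * p, (t.low + p - 1) / p * p);
            for (long j = start; j < t.high; j += p) composite[(int) (j - t.low)] = true;
        }
        long count = 0;
        for (long v = Math.max(t.low, 2); v < t.high; v++)
            if (!composite[(int) (v - t.low)]) count++;
        return count;
    }

    static long countPrimesBelow(long n, int segSize) throws InterruptedException {
        int[] base = basePrimes((int) Math.sqrt((double) n));
        BlockingQueue<Result> results = new LinkedBlockingQueue<>();
        int cores = Runtime.getRuntime().availableProcessors();
        // The LinkedBlockingQueue<Runnable> is the pool's blocking input queue of tasks.
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                cores, cores, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>());
        for (long low = 0; low < n; low += segSize) {
            Task task = new Task(low, Math.min(n, low + segSize), base);
            pool.execute(() -> results.add(new Result(task.low, countSegment(task))));
        }
        pool.shutdown();                                   // accept no new tasks
        if (!pool.awaitTermination(1, TimeUnit.HOURS))     // wait for all submitted tasks
            throw new IllegalStateException("sieve did not finish in time");
        List<Result> drained = new ArrayList<>();
        results.drainTo(drained);                          // results arrive in any order
        long total = 0;
        for (Result r : drained) total += r.count;
        return total;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(countPrimesBelow(1_000_000, 1 << 15)); // 78498
    }
}
```

Because counting only needs a sum, Result could be reduced to a single long. The low field is kept to show how a result can record which part of the range it covers, as the outline suggests.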

Hope this helps.