DEV Community

Antidisestablishmentarianism
Antidisestablishmentarianism

Posted on • Updated on

Prime Number Generation Part 1: Optimized Trial Division With AVX512 SIMD Intrinsics.

Overly commented...

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <ctime>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <vector>

#include <immintrin.h>

// Reference table of prime counts at powers of two: expected[k] is the number of
// primes <= 2^(k + 1), for k = 0..31 (the last entry is the count up to 2^32).
// Not referenced by the code in this listing; kept for checking results by hand.
int expected[] = {
    1,        2,        4,        6,        11,       18,       31,        54,
    97,       172,      309,      564,      1028,     1900,     3512,      6542,
    12251,    23000,    43390,    82025,    155611,   295947,   564163,    1077871,
    2063689,  3957809,  7603553,  14630843, 28192750, 54400028, 105097565, 203280221
};

// Reports whether `candidate` is divisible by one of the primes stored at `oddPrimes`.
//
// `oddPrimes` must point at the slot of the first odd prime (3) inside an ascending
// prime table whose unused slots hold UINT32_MAX, and the table must be long enough
// that every 16-element group starting at oddPrimes + 16 * k lies inside the allocation.
// `candidate` must be odd and larger than every prime currently stored.
static bool HasStoredPrimeDivisor(const uint32_t* oddPrimes, uint32_t candidate)
{
    // Trial division only requires testing divisors up to the square root of the number.
    const uint32_t root = static_cast<uint32_t>(std::sqrt(static_cast<double>(candidate)));

// _mm512_rem_epu32 is an SVML intrinsic: it ships with MSVC and the Intel compilers but
// not with GCC or Clang, so every other toolchain takes the scalar path below.
#if (defined(__INTEL_COMPILER) || defined(__INTEL_LLVM_COMPILER) || (defined(_MSC_VER) && !defined(__clang__))) \
    && (defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86))
    const __m512i zeros = _mm512_setzero_si512();
    const __m512i candidates = _mm512_set1_epi32(static_cast<int>(candidate));

    // Test 16 primes per iteration. Only the first lane is checked against the square
    // root: the remaining lanes hold larger primes (a zero remainder there still proves
    // the candidate composite) or UINT32_MAX padding (whose remainder is the candidate
    // itself, except for candidate == UINT32_MAX, which is divisible by 3 anyway).
    for (const uint32_t* block = oddPrimes; block[0] != UINT32_MAX && block[0] <= root; block += 16)
    {
        const __m512i remainders = _mm512_rem_epu32(candidates, _mm512_loadu_epi32(block));
        if (_mm512_cmpeq_epi32_mask(remainders, zeros) != 0)
            return true;
    }
    return false;
#else
    // Portable fallback: one prime at a time. The scan always stops inside the table
    // because the next stored prime above the square root, or the UINT32_MAX padding,
    // fails the `<= root` test (root never exceeds 65535).
    for (const uint32_t* prime = oddPrimes; *prime <= root; ++prime)
    {
        if (candidate % *prime == 0)
            return true;
    }
    return false;
#endif
}

// Returns the number of primes up to and including `limit`, and optionally saves them
// to "primes.txt" (space separated, ascending).
//
// limit: inclusive upper bound; any value from 2 to UINT32_MAX.
// save:  when true, the primes found are written to "primes.txt" in the working directory.
// Throws std::invalid_argument when limit is below 2.
static int Primes(uint32_t limit, bool save = false)
{
    if (limit < 2)
        throw std::invalid_argument("Limit must be greater than 1.");

    // Size the table from Pierre Dusart's upper bound on the prime-counting function,
    // pi(x) <= x / ln(x) * (1 + 1.2762 / ln(x)).
    const double logLimit = std::log(static_cast<double>(limit));
    const std::size_t estimate = static_cast<std::size_t>(limit / logLimit * (1 + 1.2762 / logLimit));

    // Round up to a multiple of 16 and add 1: divisor scans start at primes[1], so every
    // 16-element load that begins at primes[1 + 16 * k] must stay inside the table.
    const std::size_t capacity = ((estimate + 15) & ~static_cast<std::size_t>(15)) + 1;

    // Unused slots hold UINT32_MAX; the divisor scan relies on that padding to stop.
    std::vector<uint32_t> primes(capacity, UINT32_MAX);
    primes[0] = 2;
    std::size_t count = 1;

    // Only odd candidates are tested. The counter is 64-bit so that adding 2 cannot wrap
    // around when limit is UINT32_MAX (a 32-bit counter would loop forever there).
    for (uint64_t candidate = 3; candidate <= limit; candidate += 2)
    {
        const uint32_t value = static_cast<uint32_t>(candidate);

        // Skip primes[0] (2): no odd candidate is divisible by it.
        if (!HasStoredPrimeDivisor(primes.data() + 1, value))
            primes[count++] = value;
    }

    if (save)
    {
        // The stream is closed by its destructor at the end of this scope.
        std::ofstream myfile("primes.txt");

        for (std::size_t i = 0; i < count; ++i)
            myfile << primes[i] << " ";
    }

    return static_cast<int>(count);
}

// Counts the primes up to UINT32_MAX - 1 with Primes() and reports the count together
// with the elapsed wall-clock time in whole seconds. Returns -1 if Primes() rejects the limit.
int main()
{
    // Start of the timed section (one-second resolution).
    const time_t started = time(0);

    // Largest limit that Primes() is documented to accept.
    const uint32_t limit = UINT32_MAX - 1;

    try
    {
        const int total = Primes(limit, false);

        std::cout << "There are " << total << " primes less than or equal to " << limit << "." << std::endl;
        std::cout << "Computation completed in " << time(0) - started << " seconds." << std::endl;
    }
    catch (const std::invalid_argument& error)
    {
        std::cerr << error.what() << std::endl;
        return -1;
    }

    return 0;
}
Enter fullscreen mode Exit fullscreen mode

Incidentally, calling _mm512_setzero_si512() inside the loop, instead of reusing a zero vector created once before the loop, produced no measurable performance penalty in my benchmarks (presumably the compiler hoists the loop-invariant zeroing out of the loop).

Here is a more concise version of the Primes function if you prefer shorter code:

// Compact variant of the prime counter: returns the number of primes up to and including
// `limit` (any value from 2 to UINT32_MAX). Throws std::invalid_argument when limit is below 2.
static int Primes2(uint32_t limit)
{
    if (limit < 2)
        throw std::invalid_argument("Limit must be greater than 1.");

    // Table of odd primes only, sized from Dusart's upper bound on pi(limit) and rounded
    // up to a multiple of 16 so 16-element loads never leave it. Every unused slot,
    // including the last one, holds UINT32_MAX.
    const double logLimit = std::log(static_cast<double>(limit));
    const std::size_t capacity = (static_cast<std::size_t>(limit / logLimit * (1 + 1.2762 / logLimit)) + 15) & ~static_cast<std::size_t>(15);
    std::vector<uint32_t> oddPrimes(capacity, UINT32_MAX);
    uint32_t total = 1;     // The prime 2 is counted but never stored.

    // 64-bit counter: adding 2 cannot wrap around when limit is UINT32_MAX.
    for (uint64_t n = 3; n <= limit; n += 2)
    {
        const uint32_t candidate = static_cast<uint32_t>(n);
        const uint32_t sq = static_cast<uint32_t>(std::sqrt(static_cast<double>(candidate)));
        const uint32_t* x = oddPrimes.data();
        bool composite = false;

// _mm512_rem_epu32 is an SVML intrinsic (MSVC and Intel compilers only); other
// toolchains use the scalar loop.
#if (defined(__INTEL_COMPILER) || defined(__INTEL_LLVM_COMPILER) || (defined(_MSC_VER) && !defined(__clang__))) \
    && (defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86))
        // 16 remainders per iteration; only the first lane is compared with the square root.
        const __m512i lanes = _mm512_set1_epi32(static_cast<int>(candidate));
        for (; !composite && x[0] != UINT32_MAX && x[0] <= sq; x += 16)
            composite = _mm512_cmpeq_epi32_mask(_mm512_rem_epu32(lanes, _mm512_loadu_epi32(x)), _mm512_setzero_si512()) != 0;
#else
        // Stops at the first stored prime above the square root or at the UINT32_MAX padding.
        for (; !composite && *x <= sq; ++x)
            composite = (candidate % *x == 0);
#endif

        if (!composite)
        {
            // total - 1 odd primes are stored so far, so that is the next free slot.
            oddPrimes[total - 1] = candidate;
            total++;
        }
    }

    return static_cast<int>(total);
}
Enter fullscreen mode Exit fullscreen mode

Top comments (0)