Miller, Rabin, vector

Testing small numbers for primality is a common subtask in competitive programming, and the Miller-Rabin test is perhaps the most popular of the simple algorithms for it.

I have long wanted to play with it and try to optimize it in various ways: for example, vectorize it and see whether it gets faster.

Disclaimer. I am a Java developer; I mostly wrote C++ only at university, so in places my code may look strange.
Let me note right away that I had no goal of writing portable code. My goal was to have fun and see what I could do, not to offer the world a one-size-fits-all solution.

What problem are we solving?

Goal: determine whether a signed 32-bit integer is prime, and preferably do it quickly.

The fastest lookup in this case would be to precompute all primes that meet our requirements and put them into an array or hash table to search in later; there are about 105 million of them. I won't do this, to avoid the long precomputation.

In essence, I would like an equivalent of the following function, only faster:

bool testNaive(int32_t n) {
    if (n <= 2) return n == 2;

    for (int32_t m = 2, boundary = sqrt(n); m <= boundary; m++) {
        if (n % m == 0) return false;
    }

    return true;
}

The code above is the classic primality test that iterates over potential divisors of n. It runs in O(\sqrt n), that is, slowly.

Miller-Rabin test

The Miller-Rabin test is a probabilistic primality test. However, if the upper bound of the numbers being tested is known in advance, it is easy to turn it into a deterministic test; more on that later.

A primality test of a number n consists of several rounds. Each round gives one of two answers:
– the number is composite;
– the number may be prime, who knows.
The more rounds answer "maybe prime", the higher the chance that the number really is prime.

A round itself consists of the following steps:
– take a number a, 1 < a < n, that has not been used in previous rounds;
– check whether it is a "witness of primality" for n. If not, then n is composite; if it is, then n is called a strong pseudoprime to base a.

Below I use the same variable names as Wikipedia so the two are easy to compare. The witness check in the Miller-Rabin test uses the following algorithm:

  • First of all, weed out the even numbers: only odd n are allowed.

  • n must be represented in the form n = 2^s \cdot t + 1 with t odd (equivalently, n - 1 = 2^s \cdot t). Note that s is always greater than 0, since n - 1 is even.

  • Run the following code. It looks different from Wikipedia's pseudocode, but it does exactly the same thing; true means that n is a strong pseudoprime to base a:

int32_t s = __builtin_ctz(n - 1); // Number of trailing zero bits.
int32_t t = (n - 1) >> s;

int32_t x = pow_mod(a, t, n);     // Exponentiation modulo n.
if (x == 1) return true;

for (int i = 1; x != n - 1; i++) {
    if (i == s) return false;

    x = mul_mod(x, x, n);         // Squaring modulo n.

    if (x == 1) return false;
}

return true;

How do we choose the values of a? Personally, I will do it very simply, although I am not sure it is optimal. Wikipedia links to lists of strong pseudoprimes to base 2 and to base 3, so we could run the checks for 2 and 3 and, among the numbers that pass both, weed out those that appear in both lists…

But no: there are actually quite a lot of such numbers. However, if you take bases 2, 3 and 5, then in the range from 1 to 2^31 only 4 strong pseudoprimes remain:

  • 25326001

  • 161304001

  • 960946321

  • 1157839381

You can easily hardcode them, resulting in the following code:

bool testMillerRabinInteger(int32_t n) {
    if ((n & 1) == 0) return n == 2;
    if (n < 9) return n > 1; // 3, 5 and 7.

    int32_t s = __builtin_ctz(n - 1);
    int32_t t = (n - 1) >> s;

    int32_t primes[3] = {2, 3, 5};

    for (int32_t a : primes) {
        int32_t x = pow_mod(a, t, n);

        if (x == 1) continue;

        for (int i = 1; x != n - 1; i++) {
            if (i == s) return false;

            x = mul_mod(x, x, n);

            if (x == 1) return false;
        }
    }

    switch (n) {
        case 25326001:
        case 161304001:
        case 960946321:
        case 1157839381:
            return false;

        default:
            return true;
    }
}

This is already a fully deterministic primality test: take it and use it! It remains to define the mul_mod and pow_mod used above.

For modular multiplication it is enough to widen both the factors and the modulus to 64-bit integers to protect against overflow in the product, resulting in
(int32_t) (((int64_t) a * b) % m).
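For reference, a minimal sketch of such a mul_mod; the exact signature is my assumption, the article only gives the expression itself:

int32_t mul_mod(int32_t a, int32_t b, int32_t m) {
    // Widen to 64 bits so the product cannot overflow, then reduce.
    return (int32_t) (((int64_t) a * b) % m);
}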

For exponentiation we will use a variant of binary (fast) exponentiation that is slightly easier to implement:

int32_t pow_mod(int64_t n, int32_t power, int64_t m) {
    int64_t result = 1;

    while (power) {
        if (power & 1) result = (result * n) % m;

        n = (n * n) % m;

        power >>= 1;
    }

    return (int32_t) result;
}

Turn on -O3, run the tests: everything works! Let's hope the compiler optimized this code well and we won't have to finish anything by hand (if only…).

The baseline is ready, the introduction is over. Let's move on to the fun part.

What is vectorization

Vectorization of code is parallelization at the data level. It is roughly what video cards do: one thread of execution, but many independent pieces of data processed simultaneously.

On CPUs, vectorization is provided by special instruction sets such as SSE and AVX. Usually programmers don't have to think about this: the compiler should do all the work, you just have to give it the right optimization flags. The primary target of compiler auto-vectorization is loops over large amounts of data.

In our case there are no large loops and there is little data, but there is a loop body that executes up to 3 times for 3 different values of a. The compiler will not vectorize such a pattern on its own, don't even hope for it. My goal is to remove this loop by writing "normal" code that performs the computation for all the required values of a simultaneously.

In total, two subtasks need to be solved: vectorize the modular exponentiation, and vectorize the squaring loop with its exit conditions.

Since on average s is much smaller than t (s = O(\log n), t = O(n)), exponentiation should account for much more of the execution time, so it makes sense to focus on it first.

Vector types and instructions

More or less modern x86 CPUs have 2 kinds of vector registers: 128-bit XMM and 256-bit YMM. There are also 512-bit registers in AVX-512, but they are not as common.

We will not work with registers directly. C++ is a high-level language: it has special data types for vectors, as well as library functions on those types that map to processor instructions. Such functions are called intrinsics. To read more about them you will have to turn on a VPN, because the documentation is blocked on the territory of the Russian Federation (by Intel itself, nothing extremist).

For the purposes of this article I compile with GCC and C++17 (in case that matters). By including immintrin.h you gain access to many __m128* and __m256* types corresponding to different kinds of vectors, for example:

  • __m128i is a vector of 4 values of type int32_t;

  • __m256i is a vector of 4 values of type int64_t;

  • __m256d is a vector of 4 values of type double.
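As a small illustration (my own example, not from the article), here is how such vectors can be created, combined, and read back; it assumes AVX is enabled (e.g. -mavx):

#include <immintrin.h>
#include <cstdio>

int main() {
    __m256d a = _mm256_set1_pd(2.0);        // {2, 2, 2, 2}
    __m256d b = _mm256_set_pd(7, 5, 3, 2);  // note: set_pd lists elements from high to low
    __m256d c = _mm256_mul_pd(a, b);        // element-wise product

    double out[4];
    _mm256_storeu_pd(out, c);               // copy the vector back to ordinary memory
    std::printf("%f %f %f %f\n", out[0], out[1], out[2], out[3]);
}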

Here we see the first advantage of vectors over our loop: they can process not 3 but 4 pieces of data at once. Great, we can add a 4th value of a at no cost, for example checking 2, 3, 5 and 7. If we look at OEIS again, it turns out that every number up to 2147483648 (2^31) that is a strong pseudoprime to all of these bases is actually prime, i.e. that switch at the end is no longer needed!

For vectorized exponentiation we need multiplication and remainder operations, which should be straightforward. Recall that both of these operations were performed on 64-bit integers, which means we would work with __m256i.

And here we fall into the first serious trap:

  • _mm256_mullo_epi64, the intrinsic for multiplying 64-bit integers, does not compile:
    error: inlining failed in call to 'always_inline' '__m256i _mm256_mullo_epi64(__m256i, __m256i)': target specific option mismatch.

    Oops, it turns out it is only available with AVX-512, which my processor does not support. No surprise then, since I compiled only with -mavx -mavx2.

  • Vectorized computation of a division remainder simply does not exist. And ordinary division of 64-bit integers, _mm256_div_epi64, which the documentation pointed me to, is not even a CPU instruction but a library function that is entirely absent from GCC. We'll have to improvise.

I am not the first person to wonder about vectorized integer division. The universal answer people give on the Internet is to use floating-point division as long as you have enough precision, and when that stops being enough, to write extra error-correction code. So let's take a moment and think about how much precision double gives us if we decide to use it.

Floating point numbers

How do numbers of type double work? Probably not everyone remembers this well, so let's refresh. Any such number is represented by 3 parts:

  • 1 bit for sign;

  • 52 bits for “mantissa”;

  • 11 bits for the exponent, i.e. the position of the floating point.

So a double has 53 significant bits (the mantissa plus the implicit leading bit) and a known floating-point position, which may lie far beyond the significant bits. If we only encode integers, it becomes clear that any 53-bit unsigned number (or 54-bit signed number, respectively) can be encoded without loss of precision.

We used 64-bit int64_t to avoid overflow when multiplying 32-bit integers. By the same logic, double can be used to avoid overflow when multiplying 26-bit numbers (53 / 2). If we limit ourselves to testing primality of 26-bit numbers only, floating-point arithmetic will easily cover it. That is what we'll do: we will work with quadruples of double values, that is, with __m256d.
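As a quick sanity check of this precision argument (my own illustration, not from the article): products of 26-bit numbers are still exact in a double, while 2^53 + 1 already is not.

#include <cassert>
#include <cstdint>

int main() {
    int64_t a = (1LL << 26) - 1, b = (1LL << 26) - 1;
    double  p = (double) a * (double) b;        // at most a 52-bit product, exact
    assert((int64_t) p == a * b);               // no precision lost

    int64_t big = 1LL << 53;
    assert((double) (big + 1) == (double) big); // 54 significant bits no longer fit
}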

Modular exponentiation

Multiplication is solved. What about division with remainder? For that we will use the following formula:

a \bmod m = a - \lfloor \frac{a}{m} \rfloor \cdot m

Implementing this formula gives the following function for computing a product modulo m, for 4 double values at once (the comments show the non-vectorized analogue):

__m256d mul_mod(__m256d a, __m256d b, __m256d m) {
    __m256d c = _mm256_mul_pd(a, b);     // double c = a * b;

    __m256d tmp = _mm256_div_pd(c, m);   // double tmp = c / m;
    tmp = _mm256_floor_pd(tmp);          // tmp = floor(tmp);
    tmp = _mm256_mul_pd(tmp, m);         // tmp = tmp * m;

    return _mm256_sub_pd(c, tmp);        // return c - tmp;
}

The exponentiation code will hardly differ from the non-vectorized version:

__m256d pow_mod(__m256d n, int32_t power, __m256d m) {
    __m256d result = _mm256_set1_pd(1);  // result = {1, 1, 1, 1};

    while (power) {
        if (power & 1) result = mul_mod(result, n, m);

        n = mul_mod(n, n, m);

        power >>= 1;
    }

    return result;
}

The squaring loop

If we now try to put together what we have, we will get the following code:

bool testMillerRabinVectorized0(int32_t n) {
    if ((n & 1) == 0) return n == 2;
    if (n < 9) return n > 1;

    int32_t s = __builtin_ctz(n - 1);
    int32_t t = (n - 1) >> s;

    __m256d primes = _mm256_set_pd(2, 3, 5, 7);

    __m256d n_pd = _mm256_set1_pd(n);
    __m256d x = pow_mod(primes, t, n_pd);

    // ?
}

In place of the question mark there used to be a loop which, first, contains return / continue statements and, second, can run a different number of iterations for different a:

if (x == 1) continue;

for (int i = 1; x != n - 1; i++) {
    if (i == s) return false;

    x = mul_mod(x, x, n);

    if (x == 1) return false;
}

We need to bring this code into a form suitable for vectorization. Paradoxically, to do so we have to throw away the early exit on x == 1 inside the loop and keep only the exits on the iteration count and on x == n - 1. Let me explain.

Let's do the following: we set x = 0 once n is shown to be a strong pseudoprime to base a. If x becomes 0 for all a, then n is prime; otherwise it is not. In pseudocode:

if (x == 1)     x = 0;
if (x == n - 1) x = 0;

if (x == 0 forall a)     // n is a strong pseudoprime to all bases a.
    return true;

for (int i = 1; i < s; i++) {
    x = mul_mod(x, x, n);

    if (x == n - 1) x = 0;

    if (x == 0 forall a) // n is a strong pseudoprime to all bases a.
        return true;
}

return false;

If at some iteration x becomes 1, it will stay 1 in the following iterations, because 1 * 1 = 1. The same goes for 0, because 0 * 0 = 0. As a result, return false on the last line fires exactly when, for at least one base a, x never became n - 1 within the s allowed iterations (in particular, if x hit 1 prematurely, it stays 1 and never turns into 0).

These are exactly the exit conditions we need. A vectorized version of this code looks like this:

__m256d n_minus_one = _mm256_set1_pd(n - 1);

x = blend_zero(x, _mm256_set1_pd(1)); // if (x == 1)     x = 0;
x = blend_zero(x, n_minus_one);       // if (x == n - 1) x = 0;

if (all_zero(x)) return true;

for (int i = 1; i < s; i++) {
    x = mul_mod(x, x, n_pd);

    x = blend_zero(x, n_minus_one);   // if (x == n - 1) x = 0;

    if (all_zero(x)) return true;
}

return false;

What all_zero does should be clear from the name. blend_zero compares the two vectors element by element and writes 0 into the result where the elements match, keeping the original value of the first vector where they don't. These helpers are implemented as follows:

const __m256d ZERO = _mm256_setzero_pd();

bool all_zero(__m256d a) {
    __m256d mask_pd = _mm256_cmp_pd(a, ZERO, _CMP_NEQ_OQ);

    return 0 == _mm256_movemask_pd(mask_pd);
}

__m256d blend_zero(__m256d a, __m256d b) {
    __m256d mask_pd = _mm256_cmp_pd(a, b, _CMP_EQ_OQ);

    return _mm256_blendv_pd(a, ZERO, mask_pd);
}

The mask of type __m256d contains all-one bits (all 64 of them, spanning the whole double) in those elements for which the condition passed to cmp (here _CMP_NEQ_OQ or _CMP_EQ_OQ) holds.

movemask converts the vector mask into an ordinary bit mask, where each element corresponds to exactly one bit instead of 64. So all_zero literally checks that "there are no elements not equal to 0". Yes, a double negative, so what?

_mm256_blendv_pd builds one vector out of two, taking the element from the second vector where the mask is 1 and from the first vector where the mask is 0. I think this should be clear, and if not, there is pseudocode in the documentation.

It seems like the first version should be ready:

bool testMillerRabinVectorized(int32_t n) {
    if ((n & 1) == 0) return n == 2;
    if (n < 9) return n > 1;

    int32_t s = __builtin_ctz(n - 1);
    int32_t t = (n - 1) >> s;

    __m256d primes = _mm256_set_pd(2, 3, 5, 7);

    __m256d n_pd = _mm256_set1_pd(n);
    __m256d x = pow_mod(primes, t, n_pd);

    __m256d n_minus_one = _mm256_set1_pd(n - 1);

    x = blend_zero(x, _mm256_set1_pd(1));
    x = blend_zero(x, n_minus_one);

    if (all_zero(x)) return true;

    for (int i = 1; i < s; i++) {
        x = mul_mod(x, x, n_pd);

        x = blend_zero(x, n_minus_one);

        if (all_zero(x)) return true;
    }

    return false;
}

Comparing performance

My testing methodology will be as simple-minded as possible: take a large range of integers and test each one for primality. I ended up with the following code:

void measure(std::function<bool(int32_t)> test, uint32_t count) {
    auto start = std::chrono::high_resolution_clock::now();

    int32_t n = 0;

    for (uint32_t i = 1; i < count; i++) {
        n += test(i);
    }

    auto end = std::chrono::high_resolution_clock::now();

    std::cout << (end - start).count() * 1e-9 << " seconds" << std::endl;
    std::cout << n << " primes found" << std::endl << std::endl;
}

Let's run it for the existing implementations, remembering that our upper limit is 2^26, and cross our fingers:

std::cout << "testNaive:" << std::endl;
measure(&testNaive, 1 << 26);

std::cout << "testMillerRabinInteger:" << std::endl;
measure(&testMillerRabinInteger, 1 << 26);

std::cout << "testMillerRabinVectorized:" << std::endl;
measure(&testMillerRabinVectorized, 1 << 26);

testNaive:
33.9298 seconds
3957809 primes found

testMillerRabinInteger:
4.77004 seconds
3957809 primes found

testMillerRabinVectorized:
4.74518 seconds
3957809 primes found

Reality is full of disappointments; no miracle happened. Both functions run at roughly the same speed, within measurement error. The vectorized version may be a touch faster on average, but that hardly matters.

An amazing coincidence! It is not bad news either: computing the remainder manually turns out to be not that slow. Will I stop there? Of course not, I absolutely have to beat the non-vectorized code, so let's start tuning.

Superscalarity

I'll start with an optimization that won't bring the biggest gain, but it is very illustrative and there is no better place to put it.

Superscalar execution is also parallelism, but at the instruction level rather than explicitly in the code: it happens implicitly inside the processor. A CPU core can execute several instructions simultaneously, provided they are independent. So code that works on independent data can run faster than code that works on dependent data.

Let's look at an example: the integer pow_mod function, whose code I will repeat here:

int32_t pow_mod(int64_t n, int32_t power, int64_t m) {
    int64_t result = 1;

    while (power) {
        if (power & 1) result = (result * n) % m;

        n = (n * n) % m;

        power >>= 1;
    }

    return (int32_t) result;
}

This code has independent computations on result and n, but the problem is that they sit in different blocks, so it is hard for both the compiler and the processor to overlap them. What if the blocks were merged?

int32_t pow_mod(int64_t n, int32_t power, int64_t m) {
    int64_t result = 1;

    while (power) {
        if (power & 1) {
            result = (result * n) % m;
            n = (n * n) % m;
        } else {
            n = (n * n) % m;
        }

        power >>= 1;
    }

    return (int32_t) result;
}

Let's run the test and check. For testMillerRabinInteger it was 4.78585 seconds before the optimization and 4.25464 seconds after, which is noticeably faster. Although maybe it is not only about superscalar execution; maybe the branch predictor started behaving differently too, who knows. The main thing is that the code got faster.

Wait, stop, what have I done? I was supposed to be optimizing the vectorized code, and now the gap has only grown! Quickly doing the same to the vectorized pow_mod, rubbing my hands:

__m256d pow_mod(__m256d n, int32_t power, __m256d m) {
    __m256d result = _mm256_set1_pd(1);

    while (power) {
        if (power & 1) {
            result = mul_mod(result, n, m);
            n = mul_mod(n, n, m);
        } else {
            n = mul_mod(n, n, m);
        }

        power >>= 1;
    }

    return result;
}

And… no noticeable difference. At all.

Why? Maybe the compiler didn't inline the call to mul_mod? Apparently that is not the issue: manual inlining does not change the performance. It was roughly 4.7 seconds and it stays that way.

Maybe the compiler didn't want to reorder the instructions in the code during compilation? Let's check by manually grouping independent operations:

__m256d pow_mod(__m256d n, int32_t power, __m256d m) {
    __m256d result = _mm256_set1_pd(1);

    while (power) {
        if (power & 1) {
            __m256d c1 = _mm256_mul_pd(result, n);
            __m256d c2 = _mm256_mul_pd(n, n);

            __m256d tmp1 = _mm256_div_pd(c1, m);
            __m256d tmp2 = _mm256_div_pd(c2, m);

            tmp1 = _mm256_floor_pd(tmp1);
            tmp2 = _mm256_floor_pd(tmp2);

            tmp1 = _mm256_mul_pd(tmp1, m);
            tmp2 = _mm256_mul_pd(tmp2, m);

            result = _mm256_sub_pd(c1, tmp1);
            n = _mm256_sub_pd(c2, tmp2);
        } else {
            n = mul_mod(n, n, m);
        }

        power >>= 1;
    }

    return result;
}

Incredible: now I get approximately 4.22 to 4.28 seconds, about the same as the superscalar version of the non-vectorized algorithm! So the compiler is not always able to reorder instructions effectively, especially when intrinsics are involved, though that last part is more of a guess.

Both versions of the algorithm are now slightly faster, but there is still no clear leader among them.

Division and multiplication

Which is cheaper, multiplication or division? Multiplication, obviously. But our algorithm contains a lot of divisions, and we need to reduce their number. To do that, recall a very basic identity:

\frac{a}{b} = a \cdot \frac{1}{b}

Division is multiplication by the reciprocal. If you look at the code carefully, you will notice that we always divide by the same vector, namely n_pd (inside the helper functions it goes by the name m). I suggest introducing a second value alongside it:

__m256d n_pd  = _mm256_set1_pd(n);       // m
__m256d n_inv = _mm256_set1_pd(1.0 / n); // m_inv

And everywhere _mm256_div_pd(c, m) appears, write _mm256_mul_pd(c, m_inv) instead. Intuition says this should be faster, but by how much?
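Concretely, mul_mod then takes the reciprocal as an extra parameter; a sketch of the change (the same function appears again in the 31-bit section below):

__m256d mul_mod(__m256d a, __m256d b, __m256d m, __m256d m_inv) {
    __m256d c = _mm256_mul_pd(a, b);

    __m256d tmp = _mm256_mul_pd(c, m_inv);  // was: _mm256_div_pd(c, m)
    tmp = _mm256_floor_pd(tmp);
    tmp = _mm256_mul_pd(tmp, m);

    return _mm256_sub_pd(c, tmp);
}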

Let's make the change and run it. Before the fix: 4.27096 seconds; after: 3.05808. The difference is almost 30%, impressive! Division really is very expensive.

Do you know which other algorithm uses division? testMillerRabinInteger. Maybe if we switch it to multiplication, it will get faster too?

double mod_double(double n, double m, double m_inv) {
    double tmp = n * m_inv;
    tmp = floor(tmp);
    return n - tmp * m;
}

double pow_mod_double(double n, int32_t power, double m, double m_inv) {
    double result = 1;

    while (power) {
        if (power & 1) result = mod_double(result * n, m, m_inv);

        n = mod_double(n * n, m, m_inv);

        power >>= 1;
    }

    return result;
}

bool testMillerRabinDouble(int32_t n) {
    if ((n & 1) == 0) return n == 2;
    if (n < 9) return n > 1;

    int32_t s = __builtin_ctz(n - 1);
    int32_t t = (n - 1) >> s;

    double m = n,
           m_inv = 1.0 / m;

    int32_t primes[3] = {2, 3, 5};

    for (int32_t a : primes) {
        double x = pow_mod_double(a, t, m, m_inv);

        if (x == 1.0) continue;

        for (int i = 1; x != n - 1; i++) {
            if (i == s) return false;

            x = mod_double(x * x, m, m_inv);

            if (x == 1.0) return false;
        }
    }

    switch (n) {
        case 25326001:    // Removed the others because they are greater than 1 << 26.
            return false;

        default:
            return true;
    }
}

Testing:

testMillerRabinInteger:
4.2803 seconds
3957809 primes found

testMillerRabinDouble:
4.00873 seconds
3957809 primes found

What a twist: the non-vectorized code got faster too! That is how slow 64-bit integer * and % are: you can replace them with floating-point *, *, floor, another * and -, and it is still more efficient.

I have a hypothesis. When multiplying 64-bit integers, the processor actually produces a 128-bit product split across 2 registers (the high and low halves), one of which we don't need. Similarly with division: the processor puts the quotient in one register and the remainder in another. It simply does a lot of side work that we don't need. I'm sure there are other reasons too.

So in the end we have vectorized code that checks "all" the numbers in roughly 3 seconds, and non-vectorized code that does the same in about 4 seconds. Looks like a win. For 26-bit numbers.

How to multiply if double is not enough

What about numbers with 27 to 31 significant bits? Can their primality check be vectorized without resorting to AVX-512? Of course it can, but I'll say right away that you shouldn't expect this solution to be efficient. Nevertheless, I'm implementing it, if only to find out just how bad it turns out.

The whole task comes down to implementing a mul_mod that can multiply 31-bit numbers modulo another 31-bit number, i.e. to no longer being limited to 26 bits as in the current implementation:

__m256d mul_mod(__m256d a, __m256d b, __m256d m, __m256d m_inv) {
    __m256d c = _mm256_mul_pd(a, b);

    __m256d tmp = _mm256_mul_pd(c, m_inv);
    tmp = _mm256_floor_pd(tmp);
    tmp = _mm256_mul_pd(tmp, m);

    return _mm256_sub_pd(c, tmp);
}

Before I continue, I want to take a moment to recall one of my favorite subproblems from competitive programming. It is relevant here:

Given two numbers a and b of type int64_t, find their product modulo 10^12. Preferably efficiently.

I suggest you pause and think about how you would write such an algorithm. Just in case, let me remind you that the upper limit of signed 64-bit numbers is approximately 9 * 10^18, a fact that is easy to forget.

I would solve it as follows. Without loss of generality, assume that a and b are less than 10^12 (and non-negative); if not, we can take their remainders modulo 10^12 first. Then b can always be represented as
b = b1 * 1000000 + b2, where b1 < 1000000 and b2 < 1000000.

The nice thing about a million is that multiplying a number < 10^12 by a number < 10^6 gives a number < 10^18, which fits into a signed 64-bit integer. Such products can be computed in code without any fear; the main thing is to transform the formula for a * b correctly:

(a \cdot b) \bmod 10^{12}\\ = (a \cdot (b_1\cdot 10^6 + b_2)) \bmod 10^{12}\\ = (a \cdot b_1 \cdot 10^6 + a \cdot b_2) \bmod 10^{12} \\ = \left(((a\cdot b_1) \bmod 10^{12}) \cdot 10^6 + a \cdot b_2\right) \bmod 10^{12}

In this formula none of the products overflows int64_t, which means it can be implemented in code, and it takes only a few lines.
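A minimal sketch of this formula in code (the function name and signature are mine, assuming 0 <= a, b < 10^12):

#include <cstdint>

int64_t mul_mod_1e12(int64_t a, int64_t b) {
    const int64_t MOD = 1000000000000LL;      // 10^12
    int64_t b1 = b / 1000000, b2 = b % 1000000;

    int64_t hi = (a * b1) % MOD;              // a * b1 < 10^18, no overflow
    return (hi * 1000000 + a * b2) % MOD;     // both terms < 10^18, the sum fits in int64_t
}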

What does this have to do with the problem at hand? Everything. We similarly need to multiply two 31-bit numbers while having only 53 bits for the product. So I propose splitting one of the factors (b) into a pair b_hi and b_lo, the most significant and least significant 16 bits of the number viewed as an int32_t. The role of the million is then played by the constant 1 << 16. Each product is then at most 48 bits wide, i.e. it can be computed exactly as a product of doubles.
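Here is a scalar sketch of that 16-bit split for doubles (my own illustration of the idea; the vectorized version follows in the next section):

#include <cmath>
#include <cstdint>

double mul_mod_31(double a, double b, double m, double m_inv) {
    int32_t bi   = (int32_t) b;
    double  b_hi = (double) (bi >> 16);       // upper 15 bits of b
    double  b_lo = (double) (bi & 0xFFFF);    // lower 16 bits of b

    auto mod = [&](double x) { return x - std::floor(x * m_inv) * m; };

    double tmp = mod(a * b_hi) * 65536.0;     // ((a * b_hi) % m) * 2^16, below 2^47
    return mod(tmp + a * b_lo);               // every intermediate value fits in 53 bits
}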

Full vectorization

Let's introduce an auxiliary function mod, whose body should already look familiar:

__m256d mod(__m256d a, __m256d m, __m256d m_inv) {
    __m256d tmp = _mm256_mul_pd(a, m_inv);
    tmp = _mm256_floor_pd(tmp);
    tmp = _mm256_mul_pd(tmp, m);

    return _mm256_sub_pd(a, tmp);
}

Everything here is the same as it used to be inside mul_mod, except for the multiplication itself. Now the task is simple: just write mul_mod using the formula derived above.

Look carefully, I hope you don’t get confused:

// Mask for extracting the lower 2 bytes of 32-bit integers.
const __m128i MASK = _mm_set1_epi32((1 << 16) - 1);

// A vector filled with the value 2^16.
const __m256d K = _mm256_set1_pd(1 << 16);

__m256d mul_mod(__m256d a, __m256d b, __m256d m, __m256d m_inv) {
    // Convert the [double] vector to an [int32_t] vector.
    __m128i b_epi32 = _mm256_cvtpd_epi32(b);
    // Extract the upper 2 bytes with a right shift by 16 bits.
    __m128i b_epi32_hi = _mm_srli_epi32(b_epi32, 16);
    // Extract the lower 2 bytes by ANDing with the mask.
    __m128i b_epi32_lo = _mm_and_si128(b_epi32, MASK);

    // Convert the [int32_t] vectors back to [double].
    __m256d b_hi = _mm256_cvtepi32_pd(b_epi32_hi);
    __m256d b_lo = _mm256_cvtepi32_pd(b_epi32_lo);

    // tmp1 = ((a * b_hi) % m) * K;
    __m256d tmp1 = _mm256_mul_pd(a, b_hi);
    tmp1 = mod(tmp1, m, m_inv);
    tmp1 = _mm256_mul_pd(tmp1, K);

    // tmp2 = a * b_lo;
    __m256d tmp2 = _mm256_mul_pd(a, b_lo);

    // return (tmp1 + tmp2) % m;
    tmp2 = _mm256_add_pd(tmp1, tmp2);
    return mod(tmp2, m, m_inv);
}

Further, if desired, this code can be inlined wherever needed and the manual superscalar-style optimizations applied to it. They don't bring much, but they do turn the code into even more of a tangle, so I won't show them here.

Time for the final test. This time I set the upper limit of the checked range to 1u << 31 to cover all 31-bit integers. Here's how long it took:

testMillerRabinInteger:
160.792 seconds
105097565 primes found

testMillerRabinVectorized:
408.16 seconds
105097565 primes found

2.5 times longer than the non-vectorized integer algorithm. Rather embarrassing, but at least the prime counts match, so the code is most likely correct.

Conclusions

If you really need it, you can squeeze extra performance out of manual vectorization. But the more integer arithmetic your code uses, the lower the chances of that. Vectorization is designed primarily either for bulk operations on 32-bit integers or for floating-point arithmetic, and using it for something other than its intended purpose (as I do here) will most likely lead to suffering.

It was also interesting to see just how much the % operator slows things down. I already knew it was slow, but I hadn't realized the scale. The more you know.
