けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Yosupo Library Checker - Primality Test

Miller-Rabin 法や、モンゴメリ乗算を試せる問題ですね!

問題概要

クエリが  Q 個与えられる。

各クエリでは正整数  N が与えられるので、素数かどうか判定せよ。

制約

  •  1 \le Q \le 10^{5}
  •  1 \le N \le 10^{18}

解法

Miller-Rabin 法が使える。次のようなアルゴリズムである。

  •  N = 2 のとき:素数である
  •  N 2 以外の偶数であるとき:合成数である

以下、 N 3 以上の奇数とする。 N - 1 = 2^{s} d ( d は奇数) と表す。

 1 以上  N-1 以下の整数  a に対して、

  •  a^{d} \equiv 1 \pmod{N}
  •  a^{d} \equiv -1 \pmod{N}
  •  a^{2d} \equiv -1 \pmod{N}
  •  \dots
  •  a^{2^{s-2}d} \equiv -1 \pmod{N}
  •  a^{2^{s-1}d} \equiv -1 \pmod{N}

のいずれかが成り立つことを「 a N のテストに通る」と呼ぶことにする。 N のテストに通らない  a が検出されたならば、直ちに  N は合成数であると言える。

また、

  •  N \lt 4759123141 のとき: a = 2, 7, 61
  •  N \le 2^{64} のとき: a = 2, 325, 9375, 28178, 450775, 9780504, 1795265022

に対してテストに通るならば、 N は素数であると言い切れる。

drken1215.hatenablog.com

計算量を評価する。 \mod{N} における整数の乗算を  O(1) でできるものとして、テストに用いる  a の個数を  K としたとき、Miller-Rabin 法の計算量は  O(K \log N) と評価できる。

よって、全体の計算量は  O(QK\log N) となる。

コード

 \mod{N} における整数の乗算を実現する際には、64 bit 整数で表現できない大きさの整数が登場することが問題となる。

ここでは、モンゴメリ乗算を用いて高速化する。モンゴメリ乗算は、時間のかかる除算を実質的に行うことなく、乗算・加算・減算・シフト演算のみで、効率的に整数の積の剰余を求めることのできるアルゴリズムである。

モンゴメリ乗算を用いると、32 bit 整数に収まらない整数を法とする効率的な modint を作れる。

#include <bits/stdc++.h>
using namespace std;

// montgomery modint (MOD < 2^62, MOD is odd)
struct MontgomeryModInt64 {
    using mint = MontgomeryModInt64;
    using u64 = uint64_t;
    using u128 = __uint128_t;
    
    // static menber
    static u64 MOD;
    static u64 INV_MOD;  // INV_MOD * MOD ≡ 1 (mod 2^64)
    static u64 T128;  // 2^128 (mod MOD)
    
    // inner value
    u64 val;
    
    // constructor
    MontgomeryModInt64() : val(0) { }
    MontgomeryModInt64(long long v) : val(reduce((u128(v) + MOD) * T128)) { }
    u64 get() const {
        u64 res = reduce(val);
        return res >= MOD ? res - MOD : res;
    }
    
    // mod getter and setter
    static u64 get_mod() { return MOD; }
    static void set_mod(u64 mod) {
        assert(mod < (1LL << 62));
        assert((mod & 1));
        MOD = mod;
        T128 = -u128(mod) % mod;
        INV_MOD = get_inv_mod();
    }
    static u64 get_inv_mod() {
        u64 res = MOD;
        for (int i = 0; i < 5; ++i) res *= 2 - MOD * res;
        return res;
    }
    static u64 reduce(const u128 &v) {
        return (v + u128(u64(v) * u64(-INV_MOD)) * MOD) >> 64;
    }
    
    // arithmetic operators
    mint operator - () const { return mint() - mint(*this); }
    mint operator + (const mint &r) const { return mint(*this) += r; }
    mint operator - (const mint &r) const { return mint(*this) -= r; }
    mint operator * (const mint &r) const { return mint(*this) *= r; }
    mint operator / (const mint &r) const { return mint(*this) /= r; }
    mint& operator += (const mint &r) {
        if ((val += r.val) >= 2 * MOD) val -= 2 * MOD;
        return *this;
    }
    mint& operator -= (const mint &r) {
        if ((val += 2 * MOD - r.val) >= 2 * MOD) val -= 2 * MOD;
        return *this;
    }
    mint& operator *= (const mint &r) {
        val = reduce(u128(val) * r.val);
        return *this;
    }
    mint& operator /= (const mint &r) {
        *this *= r.inv();
        return *this;
    }
    mint inv() const { return pow(MOD - 2); }
    mint pow(u128 n) const {
        mint res(1), mul(*this);
        while (n > 0) {
            if (n & 1) res *= mul;
            mul *= mul;
            n >>= 1;
        }
        return res;
    }

    // other operators
    bool operator == (const mint &r) const {
        return (val >= MOD ? val - MOD : val) == (r.val >= MOD ? r.val - MOD : r.val);
    }
    bool operator != (const mint &r) const {
        return (val >= MOD ? val - MOD : val) != (r.val >= MOD ? r.val - MOD : r.val);
    }
    friend istream& operator >> (istream &is, mint &x) {
        long long t;
        is >> t;
        x = mint(t);
        return is;
    }
    friend ostream& operator << (ostream &os, const mint &x) {
        return os << x.get();
    }
    friend mint modpow(const mint &r, long long n) {
        return r.pow(n);
    }
    friend mint modinv(const mint &r) {
        return r.inv();
    }
};

typename MontgomeryModInt64::u64
MontgomeryModInt64::MOD, MontgomeryModInt64::INV_MOD, MontgomeryModInt64::T128;

// Miller-Rabin
bool MillerRabin(long long N, vector<long long> A) {
    using mint = MontgomeryModInt64;
    mint::set_mod(N);
    
    long long s = 0, d = N - 1;
    while (d % 2 == 0) {
        ++s;
        d >>= 1;
    }
    for (auto a : A) {
        if (N <= a) return true;
        mint x = mint(a).pow(d);
        if (x != 1) {
            long long t;
            for (t = 0; t < s; ++t) {
                if (x == N - 1) break;
                x *= x;
            }
            if (t == s) return false;
        }
    }
    return true;
}

bool is_prime(long long N) {
    if (N <= 1) return false;
    else if (N == 2) return true;
    else if (N % 2 == 0) return false;
    else if (N < 4759123141LL)
        return MillerRabin(N, {2, 7, 61});
    else
        return MillerRabin(N, {2, 325, 9375, 28178, 450775, 9780504, 1795265022});
}

void YosupoPrimarityTest() {
    cin.tie(0);
    ios::sync_with_stdio(false);
    int N;
    cin >> N;
    for (int i = 0; i < N; ++i) {
        long long A;
        cin >> A;
        if (is_prime(A))
            cout << "Yes" << endl;
        else
            cout << "No" << endl;
    }
}

int main() {
    YosupoPrimarityTest();
}