CP-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub Honam0905/CP-library

:heavy_check_mark: test/yosupo/Math/montgomery_multiplication.test.cpp

Depends on

Code

#define PROBLEM "https://judge.yosupo.jp/problem/primality_test"
#include "Misc/marco.hpp"
#include "Misc/debug.hpp"
// Large sentinel values: INF = 1e9 for int, INFI = 1e15 for ll.
constexpr int INF = 1'000'000'000;
constexpr ll INFI = 1'000'000'000'000'000;
//----------Author: Nguyen Ho Nam,UIT, Saigon-----------------
#include "Modint/montgomery_multiplication.hpp"
/**
 * Deterministic Miller-Rabin primality test for 64-bit integers.
 *
 * Uses the seven-base set {2, 325, 9375, 28178, 450775, 9780504, 1795265022},
 * which is known to be deterministic for every 64-bit input.
 * For x < 2^63 the arithmetic runs in Montgomery form (Montgomery_u64).
 * Larger x fall back to plain 128-bit multiplication, because
 * Montgomery_u64::reduce picks its final correction with a signed test on a
 * 64-bit difference, which is only valid for moduli below 2^63.
 *
 * @param x value to test
 * @return true iff x is prime
 */
bool isPrime(u64 x) {
    if (x < 64) {
        // Bit p of this mask is set exactly when p is a prime below 64.
        return (u64(1) << x) & 0x28208a20a08a28ac;
    }
    if (x % 2 == 0) {
        return false;
    }
    // Primes above 63 that divide one of the bases. For them that base is
    // 0 (mod x), so its round would wrongly report "composite".
    if (x == 73 || x == 193 || x == 407521 || x == 299210837) {
        return true;
    }

    static constexpr int kRounds = 7;
    static constexpr u64 kBases[kRounds] = {2, 325, 9375, 28178, 450775, 9780504, 1795265022};

    const int k = __builtin_ctzll(x - 1);
    const u64 d = (x - 1) >> k;  // x - 1 == d * 2^k with d odd

    // Runs every round in one number representation:
    //   mul(a, b)        multiplies two representatives modulo x,
    //   enter(v)         maps a plain integer into the representation,
    //   one / minus_one  are the representatives of 1 and x - 1.
    // The rounds advance in lock-step, so their independent multiplications
    // can overlap in the CPU pipeline.
    auto all_rounds_pass = [&](auto mul, auto enter, u64 one, u64 minus_one) {
        u64 t[kRounds], r[kRounds];
        bool ok[kRounds];
        for (int i = 0; i < kRounds; ++i) {
            t[i] = enter(kBases[i]);
            r[i] = one;
        }
        // r[i] = base_i^d by right-to-left binary exponentiation.
        for (u64 b = d; b; b >>= 1) {
            if (b & 1) {
                for (int i = 0; i < kRounds; ++i) r[i] = mul(r[i], t[i]);
            }
            for (int i = 0; i < kRounds; ++i) t[i] = mul(t[i], t[i]);
        }
        // Round i passes if base_i^d == +-1, or base_i^(d * 2^j) == -1 for some 0 < j < k.
        for (int i = 0; i < kRounds; ++i) ok[i] = r[i] == one || r[i] == minus_one;
        for (int j = 1; j < k; ++j) {
            for (int i = 0; i < kRounds; ++i) {
                r[i] = mul(r[i], r[i]);
                ok[i] = ok[i] || r[i] == minus_one;
            }
        }
        for (int i = 0; i < kRounds; ++i) {
            if (!ok[i]) return false;
        }
        return true;
    };

    if (x < (u64(1) << 63)) {
        Montgomery_u64 m;
        m.set(x);
        return all_rounds_pass(
            [&m](u64 a, u64 b) { return mul_m64(&m, a, b); },
            [&m](u64 v) { return m.to(v); },  // v -> v * 2^64 mod x
            m.r1, m.min(m.r1));               // Montgomery forms of 1 and -1
    }
    // Modulus too large for Montgomery_u64: use exact 128-bit products.
    return all_rounds_pass(
        [x](u64 a, u64 b) { return (u64)((u128)a * b % x); },
        [x](u64 v) { return v % x; },
        u64(1), x - 1);
}

int main() {
    int t; cin>>t;
    while(t--){
      u64 n; cin>>n;
      cout<<(isPrime(n)?"Yes":"No")<<'\n';
    }
    return 0;
}
#line 1 "test/yosupo/Math/montgomery_multiplication.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/primality_test"
#line 2 "Misc/marco.hpp"
// On judges with GCC >= 12, Ofast alone is enough
// #pragma GCC optimize("O3,no-stack-protector,fast-math,unroll-loops,tree-vectorize")
// MLE optimization
// #pragma GCC optimize("conserve-stack")
// Old judges
// #pragma GCC target("sse4.2,popcnt,lzcnt,abm,mmx,fma,bmi,bmi2")
// New judges. Test with assert(__builtin_cpu_supports("avx2"));
// #pragma GCC target("avx2,popcnt,lzcnt,abm,bmi,bmi2,fma,tune=native")
// Atcoder
// #pragma GCC target("avx2,popcnt,lzcnt,abm,bmi,bmi2,fma")
/*
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
typedef tree<int,null_type,less<int>,rb_tree_tag,tree_order_statistics_node_update> ods;
- insert(x),erase(x)
- find_by_order(k): return iterator to the k-th smallest element
- order_of_key(x): the number of elements that are strictly smaller
*/
#include<bits/stdc++.h>
using namespace std;

// Signed integer shorthands.
using ll = long long;
using i128 = __int128;
// Unsigned integer shorthands.
using u32 = unsigned int;
using u64 = unsigned long long;
using u128 = unsigned __int128;
// Extended-precision floating-point shorthands.
using ld = long double;
using f128 = __float128;
 
 
// Short names for common pair types.
#define pii pair<int,int>
#define pll pair<ll,ll>

// Whole-container iterator ranges, e.g. sort(all(v)) / sort(rall(v)).
#define all(x) (x).begin(),(x).end()
#define rall(x) (x).rbegin(),(x).rend()
// Range [x, x + n) over a raw array or pointer, using a variable named n at
// the call site. x is parenthesized in both halves so arguments such as
// ars(cond ? a : b) expand to the intended range.
#define ars(x) (x),((x)+n)

// Processor time used so far, in seconds (std::clock based).
#define TIME  (1.0 * clock() / CLOCKS_PER_SEC)

// Loop shorthands: For/rep are half-open [a, b); FOR/REP are closed [a, b];
// rev/REV count down (exclusive / inclusive lower bound).
#define For(i,a,b) for (int i=(a); i<(b); i++)
#define rep(i,a) For(i,0,a)
#define rev(i,a,b) for (int i=(a); i>(b); i--)
#define FOR(i,a,b) for (int i=(a); i<=(b); i++)
#define REP(i,a) FOR(i,1,a)
#define REV(i,a,b) for (int i=(a); i>=(b); i--)

#define pb push_back
#define eb emplace_back
#define mp make_pair
#define fi first
#define se second
// Fast iostreams: unsync from C stdio and untie cin from cout.
#define FT ios_base::sync_with_stdio(false); cin.tie(nullptr);
 
// Process-wide Mersenne Twister, seeded from the steady clock's tick count.
mt19937 rng(static_cast<mt19937::result_type>(chrono::steady_clock::now().time_since_epoch().count()));

// Vector shorthands, nested up to five levels.
using vi = vector<int>;
using vll = vector<ll>;
template <typename T> using vc = vector<T>;
template <typename T> using vvc = vector<vc<T>>;
template <typename T> using vvvc = vector<vvc<T>>;
template <typename T> using vvvvc = vector<vvvc<T>>;
template <typename T> using vvvvvc = vector<vvvvc<T>>;
// Max-heap (pq) and min-heap (pqg) priority queues.
template <typename T> using pq = priority_queue<T>;
template <typename T> using pqg = priority_queue<T, vector<T>, greater<T>>;

// Nested-vector declarations. The trailing arguments go to the innermost
// vector's constructor, e.g. vv(int, g, h, w, 0) declares an h x w grid of 0.
#define vv(type, name, h, ...) vector<vector<type>> name(h, vector<type>(__VA_ARGS__))
#define vvv(type, name, h, w, ...) \
  vector<vector<vector<type>>> name(h, vector<vector<type>>(w, vector<type>(__VA_ARGS__)))
#define vvvv(type, name, a, b, c, ...) \
  vector<vector<vector<vector<type>>>> name(a, vector<vector<vector<type>>>(b, vector<vector<type>>(c, vector<type>(__VA_ARGS__))))
 
//template <class T>
//using ods =
//   tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
 
// chkmin: if y is smaller than x (x > y), assign x = y and return true; else return false.
// chkmax: the same with x < y.
// The candidate has its own type U, so mixed calls such as chkmin(ll_var, 0)
// deduce cleanly. U defaults to T, so braced candidates like
// chkmax(p, {a, b}) keep working.
// (pq / pqg are declared once, next to the other container aliases.)
template <typename T, typename U = T> bool chkmin(T &x, const U &y) { return x > y ? x = y, 1 : 0; }
template <typename T, typename U = T> bool chkmax(T &x, const U &y) { return x < y ? x = y, 1 : 0; }
#line 1 "Misc/debug.hpp"
// Debug printing to stderr. __print writes one value. Strings and chars are
// quoted, and pairs and iterable containers are printed recursively as {a, b, ...}.
void __print(int x) {cerr << x;}
void __print(long x) {cerr << x;}
void __print(long long x) {cerr << x;}
void __print(unsigned x) {cerr << x;}
void __print(unsigned long x) {cerr << x;}
void __print(unsigned long long x) {cerr << x;}
void __print(float x) {cerr << x;}
void __print(double x) {cerr << x;}
void __print(long double x) {cerr << x;}
void __print(char x) {cerr << '\'' << x << '\'';}
void __print(const char *x) {cerr << '\"' << x << '\"';}
void __print(const string &x) {cerr << '\"' << x << '\"';}
void __print(bool x) {cerr << (x ? "true" : "false");}

template<typename T, typename V>
void __print(const pair<T, V> &x) {cerr << '{'; __print(x.first); cerr << ", "; __print(x.second); cerr << '}';}
// Any range-for iterable container.
template<typename T>
void __print(const T &x) {int f = 0; cerr << '{'; for (auto &i: x) cerr << (f++ ? ", " : ""), __print(i); cerr << "}";}
// vector<bool> yields proxy references, so iterate it by index.
template<>
void __print(const vector<bool> &x) {int f = 0; cerr << '{'; for (size_t i = 0; i < x.size(); ++i) cerr << (f++ ? ", " : ""), __print(x[i]); cerr << "}";}
// Comma-separated printer terminated by "]\n".
void _print() {cerr << "]\n";}
template <typename T, typename... V>
void _print(T t, V... v) {__print(t); if (sizeof...(v)) cerr << ", "; _print(v...);}

// Comma-separated printer terminated by a newline (used by dbg).
void dbg_out() { cerr << endl; }
template<typename Head, typename... Tail> void dbg_out(Head H, Tail... T) { __print(H); if (sizeof...(T)) cerr << ", "; dbg_out(T...); }
// dbg(a, b) prints "[a, b]:" followed by the values. The do { } while (0)
// wrapper makes the macro a single statement, so it is safe inside an
// unbraced if/else.
#define dbg(...) do { cerr << "[" << #__VA_ARGS__ << "]:"; dbg_out(__VA_ARGS__); } while (0)
#line 4 "test/yosupo/Math/montgomery_multiplication.test.cpp"
// Large sentinel values: INF = 1e9 for int, INFI = 1e15 for ll.
constexpr int INF = 1'000'000'000;
constexpr ll INFI = 1'000'000'000'000'000;
//----------Author: Nguyen Ho Nam,UIT, Saigon-----------------
#line 2 "Modint/montgomery_multiplication.hpp"
/*
  safe_mod / inv_gcd / mod_inv, adapted from the AtCoder Library.
  reference: https://github.com/atcoder/ac-library/blob/master/atcoder/math.hpp
  Generalized over the integer type T, including unsigned T (see inv_gcd).
*/
/// Returns x mod m in [0, m). Requires m >= 1.
template<class T>
constexpr T safe_mod(T x, T m) {
    x %= m;
    if (x < 0) x += m;
    return x;
}
/// Extended Euclid. Returns {g, x} with g = gcd(a, b), a * x == g (mod b)
/// and 0 <= x < b / g. Requires b >= 1. For unsigned T this additionally
/// needs b <= 2^(bits - 1), so that wrapped negative coefficients can be told
/// apart from valid ones.
template<class T>
constexpr std::pair<T, T> inv_gcd(T a, T b) {
    a = safe_mod(a, b);
    if (a == 0) return {b, 0};

    T s = b, t = a;
    T m0 = 0, m1 = 1;

    while (t) {
        T u = s / t;
        s -= t * u;
        m0 -= m1 * u;
        // Manual swaps: std::swap is not constexpr before C++20.
        auto tmp = s;
        s = t;
        t = tmp;
        tmp = m0;
        m0 = m1;
        m1 = tmp;
    }

    // The true coefficient m0 lies in (-b/s, b/s). For signed T a negative
    // value shows up directly. For unsigned T the updates above wrapped modulo
    // 2^bits, so a negative m0 shows up as a value >= b/s. In both cases,
    // adding b/s (wrapping again for unsigned T) gives the representative in [0, b/s).
    if (m0 < 0 || m0 >= b / s) m0 += b / s;
    return {s, m0};
}
/// Returns the inverse of x modulo m, in [0, m).
/// Asserts m >= 1 and gcd(x, m) == 1.
template<class T>
T mod_inv(T x, T m) {
    assert(1 <= m);
    auto z = inv_gcd(x, m);
    assert(z.first == 1);
    return z.second;
}
/*
  montgomery multiplication
  @see https://en.wikipedia.org/wiki/Montgomery_modular_multiplication
  @see https://cp-algorithms.com/algebra/montgomery_multiplication.html
*/
/**
 * 32-bit Montgomery arithmetic context (R = 2^32).
 *
 * Call set(mod) before use. set() assumes mod is odd: the Newton iteration
 * for ni only yields an inverse modulo 2^32 for odd mod. reduce() picks its
 * final correction with a signed test on a 32-bit difference, and add()/shl()
 * need 2 * mod to fit in 32 bits, so mod must also be below 2^31.
 */
struct Montgomery_u32 {
    // n: modulus; n2: 2 * n (not read by any member below);
    // ni: n^{-1} mod 2^32; r1 = R mod n (Montgomery form of 1);
    // r2 = R^2 mod n (used by to()); r3 = R^3 mod n (used by inv()).
    u32 n, n2, ni, r1, r2, r3;

    void set(u32 mod) {
        n = mod;
        n2 = mod << 1;
        ni = mod;
        // Newton steps: each doubles the number of correct low bits of
        // mod^{-1} (mod 2^32), starting from 3 bits (mod * mod == 1 mod 8).
        ni *= 2 - ni * mod;
        ni *= 2 - ni * mod;
        ni *= 2 - ni * mod;
        ni *= 2 - ni * mod;
        r1 = (u32)(int32_t)-1 % mod + 1;  // (2^32 - 1) mod n + 1 == 2^32 mod n for odd n > 1
        r2 = (u64)(int64_t)-1 % mod + 1;  // likewise 2^64 mod n
        r3 = (u32)(((u64)r1 * (u64)r2) % mod);
    }

    // Returns a * R^{-1} mod n in [0, n); requires a < n * 2^32.
    u32 reduce(u64 a) {
        u32 y = (u32)(a >> 32) - (u32)(((u64)((u32)a * ni) * n) >> 32);
        return (int32_t)y < 0 ? y + n : y;
    }

    // Plain value -> Montgomery form (a * R mod n).
    u32 to(u32 a) {
        return reduce((u64)a * r2);
    }

    // Montgomery form -> plain value.
    u32 from(u32 a) {
        return reduce((u64)a);
    }

    // (a + b) mod n for a, b in [0, n).
    u32 add(u32 a, u32 b) {
        a += b;
        a -= (a >= n ? n : 0);
        return a;
    }

    // (a - b) mod n for a, b in [0, n).
    u32 sub(u32 a, u32 b) {
        a += (a < b ? n : 0);
        a -= b;
        return a;
    }

    // Negation ("minus"): (n - a) mod n.
    u32 min(u32 a) {
        return sub(0, a);
    }

    // 2a mod n.
    u32 shl(u32 a) {
        return (a <<= 1) >= n ? a - n : a;
    }

    // a / 2 mod n for odd n; for odd a this is (a + n) / 2.
    u32 shr(u32 a) {
        return (a & 1) ? ((a >> 1) + (n >> 1) + 1) : (a >> 1);
    }

    // Given a = x * R mod n (Montgomery form), returns x^{-1} * R mod n.
    // Requires gcd(a, n) == 1 (asserted inside mod_inv).
    u32 inv(u32 a) {
        // mod_inv(a) = x^{-1} R^{-1}; reducing R^3 times that gives x^{-1} R.
        // Extended-Euclid coefficients can be negative, so compute the inverse
        // in long long. The result lies in [0, n) and fits back into u32.
        const u32 a_inv = (u32)mod_inv<long long>((long long)a, (long long)n);
        return reduce((u64)r3 * a_inv);
    }
};

/**
 * 64-bit Montgomery arithmetic context (R = 2^64).
 *
 * Call set(mod) before use. set() assumes mod is odd: the Newton iteration
 * for ni only yields an inverse modulo 2^64 for odd mod. reduce() picks its
 * final correction with a signed test on a 64-bit difference, and add()/shl()
 * need 2 * mod to fit in 64 bits, so mod must also be below 2^63.
 */
struct Montgomery_u64 {
    // n: modulus; n2: 2 * n (not read by any member below);
    // ni: n^{-1} mod 2^64; r1 = R mod n (Montgomery form of 1);
    // r2 = R^2 mod n (used by to()); r3 = R^3 mod n (used by inv()).
    u64 n, n2, ni, r1, r2, r3;

    void set(u64 mod) {
        n = mod;
        n2 = mod << 1;
        ni = mod;
        // Newton steps: each doubles the number of correct low bits of
        // mod^{-1} (mod 2^64), starting from 3 bits (mod * mod == 1 mod 8).
        ni *= 2 - ni * mod;
        ni *= 2 - ni * mod;
        ni *= 2 - ni * mod;
        ni *= 2 - ni * mod;
        ni *= 2 - ni * mod;
        r1 = (u64)(int64_t)-1 % mod + 1;   // (2^64 - 1) mod n + 1 == 2^64 mod n for odd n > 1
        r2 = (u128)(i128)-1 % mod + 1;     // likewise 2^128 mod n
        r3 = (u64)(((u128)r1 * (u128)r2) % mod);
    }

    // Returns a * R^{-1} mod n in [0, n); requires a < n * 2^64.
    u64 reduce(u128 a) {
        u64 y = (u64)(a >> 64) - (u64)(((u128)((u64)a * ni) * n) >> 64);
        return (int64_t)y < 0 ? y + n : y;
    }

    // Plain value -> Montgomery form (a * R mod n).
    u64 to(u64 a) {
        return reduce((u128)a * r2);
    }

    // Montgomery form -> plain value.
    u64 from(u64 a) {
        return reduce((u128)a);
    }

    // (a + b) mod n for a, b in [0, n).
    u64 add(u64 a, u64 b) {
        a += b;
        a -= (a >= n ? n : 0);
        return a;
    }

    // (a - b) mod n for a, b in [0, n).
    u64 sub(u64 a, u64 b) {
        a += (a < b ? n : 0);
        a -= b;
        return a;
    }

    // Negation ("minus"): (n - a) mod n.
    u64 min(u64 a) {
        return sub(0, a);
    }

    // 2a mod n.
    u64 shl(u64 a) {
        return (a <<= 1) >= n ? a - n : a;
    }

    // a / 2 mod n for odd n; for odd a this is (a + n) / 2.
    u64 shr(u64 a) {
        return (a & 1) ? ((a >> 1) + (n >> 1) + 1) : (a >> 1);
    }

    // Given a = x * R mod n (Montgomery form), returns x^{-1} * R mod n.
    // Requires gcd(a, n) == 1 (asserted inside mod_inv).
    u64 inv(u64 a) {
        // mod_inv(a) = x^{-1} R^{-1}; reducing R^3 times that gives x^{-1} R.
        // Extended-Euclid coefficients can be negative, so compute the inverse
        // in long long. This is lossless because n < 2^63.
        const u64 a_inv = (u64)mod_inv<long long>((long long)a, (long long)n);
        return reduce((u128)r3 * a_inv);
    }
};
// Montgomery multiplication - 32-bit
/// Montgomery product: returns a * b * R^{-1} mod n. Montgomery-form inputs
/// give a Montgomery-form result.
u32 mul_m32(struct Montgomery_u32 *m, u32 a, u32 b) {
    const u64 product = (u64)a * b;
    return m->reduce(product);
}

/// With a, b in Montgomery form, returns a / b in Montgomery form.
/// Requires gcd(b, n) == 1.
u32 div_m32(struct Montgomery_u32 *m, u32 a, u32 b) {
    const u32 b_inv = m->inv(b);
    return mul_m32(m, a, b_inv);
}

/// With a = x * R mod n (Montgomery form), returns x^k mod n in PLAIN form
/// (the accumulator is converted with m->from before returning).
u32 pow_m32(struct Montgomery_u32 *m, u32 a, u64 k) {
    u32 acc = m->r1;  // Montgomery form of 1
    for (u32 base = a; k != 0; k >>= 1) {
        if (k & 1) {
            acc = mul_m32(m, acc, base);
        }
        base = mul_m32(m, base, base);
    }
    return m->from(acc);
}

// Montgomery multiplication - 64-bit
/// Montgomery product: returns a * b * R^{-1} mod n. Montgomery-form inputs
/// give a Montgomery-form result.
u64 mul_m64(struct Montgomery_u64 *m, u64 a, u64 b) {
    const u128 product = (u128)a * b;
    return m->reduce(product);
}

/// With a, b in Montgomery form, returns a / b in Montgomery form.
/// Requires gcd(b, n) == 1.
u64 div_m64(struct Montgomery_u64 *m, u64 a, u64 b) {
    const u64 b_inv = m->inv(b);
    return mul_m64(m, a, b_inv);
}

/// With a = x * R mod n (Montgomery form), returns x^k mod n in PLAIN form
/// (the accumulator is converted with m->from before returning).
u64 pow_m64(struct Montgomery_u64 *m, u64 a, u64 k) {
    u64 acc = m->r1;  // Montgomery form of 1
    for (u64 base = a; k != 0; k >>= 1) {
        if (k & 1) {
            acc = mul_m64(m, acc, base);
        }
        base = mul_m64(m, base, base);
    }
    return m->from(acc);
}
#line 8 "test/yosupo/Math/montgomery_multiplication.test.cpp"
/**
 * Deterministic Miller-Rabin primality test for 64-bit integers.
 *
 * Uses the seven-base set {2, 325, 9375, 28178, 450775, 9780504, 1795265022},
 * which is known to be deterministic for every 64-bit input.
 * For x < 2^63 the arithmetic runs in Montgomery form (Montgomery_u64).
 * Larger x fall back to plain 128-bit multiplication, because
 * Montgomery_u64::reduce picks its final correction with a signed test on a
 * 64-bit difference, which is only valid for moduli below 2^63.
 *
 * @param x value to test
 * @return true iff x is prime
 */
bool isPrime(u64 x) {
    if (x < 64) {
        // Bit p of this mask is set exactly when p is a prime below 64.
        return (u64(1) << x) & 0x28208a20a08a28ac;
    }
    if (x % 2 == 0) {
        return false;
    }
    // Primes above 63 that divide one of the bases. For them that base is
    // 0 (mod x), so its round would wrongly report "composite".
    if (x == 73 || x == 193 || x == 407521 || x == 299210837) {
        return true;
    }

    static constexpr int kRounds = 7;
    static constexpr u64 kBases[kRounds] = {2, 325, 9375, 28178, 450775, 9780504, 1795265022};

    const int k = __builtin_ctzll(x - 1);
    const u64 d = (x - 1) >> k;  // x - 1 == d * 2^k with d odd

    // Runs every round in one number representation:
    //   mul(a, b)        multiplies two representatives modulo x,
    //   enter(v)         maps a plain integer into the representation,
    //   one / minus_one  are the representatives of 1 and x - 1.
    // The rounds advance in lock-step, so their independent multiplications
    // can overlap in the CPU pipeline.
    auto all_rounds_pass = [&](auto mul, auto enter, u64 one, u64 minus_one) {
        u64 t[kRounds], r[kRounds];
        bool ok[kRounds];
        for (int i = 0; i < kRounds; ++i) {
            t[i] = enter(kBases[i]);
            r[i] = one;
        }
        // r[i] = base_i^d by right-to-left binary exponentiation.
        for (u64 b = d; b; b >>= 1) {
            if (b & 1) {
                for (int i = 0; i < kRounds; ++i) r[i] = mul(r[i], t[i]);
            }
            for (int i = 0; i < kRounds; ++i) t[i] = mul(t[i], t[i]);
        }
        // Round i passes if base_i^d == +-1, or base_i^(d * 2^j) == -1 for some 0 < j < k.
        for (int i = 0; i < kRounds; ++i) ok[i] = r[i] == one || r[i] == minus_one;
        for (int j = 1; j < k; ++j) {
            for (int i = 0; i < kRounds; ++i) {
                r[i] = mul(r[i], r[i]);
                ok[i] = ok[i] || r[i] == minus_one;
            }
        }
        for (int i = 0; i < kRounds; ++i) {
            if (!ok[i]) return false;
        }
        return true;
    };

    if (x < (u64(1) << 63)) {
        Montgomery_u64 m;
        m.set(x);
        return all_rounds_pass(
            [&m](u64 a, u64 b) { return mul_m64(&m, a, b); },
            [&m](u64 v) { return m.to(v); },  // v -> v * 2^64 mod x
            m.r1, m.min(m.r1));               // Montgomery forms of 1 and -1
    }
    // Modulus too large for Montgomery_u64: use exact 128-bit products.
    return all_rounds_pass(
        [x](u64 a, u64 b) { return (u64)((u128)a * b % x); },
        [x](u64 v) { return v % x; },
        u64(1), x - 1);
}

int main() {
    int t; cin>>t;
    while(t--){
      u64 n; cin>>n;
      cout<<(isPrime(n)?"Yes":"No")<<'\n';
    }
    return 0;
}
Back to top page