けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 042 D - あまり (試験管黄色)

離散対数の verify に

問題概要

4 つの整数  X, P, A, B が与えられる。 P は素数である。

整数  i A \le i \le B の範囲を動くときの、 X^{i} P で割ったあまりの最小値を求めよ。

制約

  •  1 \le X \lt P \lt 2^{31}
  •  0 \le A \le B \lt 2^{31}
  •  P は素数

考えたこと

まず、

 X^{r} ≡ 1 \pmod P

を満たす最小の正の整数  r を求める (これを位数と呼ぶ)。これは離散対数によって求められる。

まず  A 以上  B 以下の整数の範囲内に  r の倍数が含まれるならば、答えは明らかに 1 である。そうでないときは、次のようにしてよい ( A, B r で割ったときの商が等しいため)。

  •  A A %  r で置き換える
  •  B B %  r で置き換える

場合分け

このあとは次のように考えることができる。

  •  B - A + 1 が小さいとき: i = A, A+1, \dots, B に対する  X^{i} \pmod P をすべて求めることができる

  •  B - A + 1 が大きいとき: X^{i} \pmod P のとりうる値が多いので、 b をランダムに選んだときに  X^{i} ≡ b \pmod P を満たす最小の  i A \le i \le B の範囲内におさまる確率が高い

前者については、単純な全探索。後者については  b = 1, 2, \dots と順に試していって条件を満たした時点で  b を返すようにすればよい。

コード

#include <bits/stdc++.h>
using namespace std;

// a^b
long long modpow(long long a, long long n, long long mod) {
    long long res = 1;
    while (n > 0) {
        if (n & 1) res = res * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return res;
}

// a^-1
long long modinv(long long a, long long m) {
    long long b = m, u = 1, v = 0;
    while (b) {
        long long t = a / b;
        a -= t * b; swap(a, b);
        u -= t * v; swap(u, v);
    }
    u %= m;
    if (u < 0) u += m;
    return u;
}

// a^x ≡ b (mod. m) となる最小の正の整数 x を求める
long long modlog(long long a, long long b, int m) {
    a %= m, b %= m;

    // calc sqrt{M}
    long long lo = -1, hi = m;
    while (hi - lo > 1) {
        long long mid = (lo + hi) / 2;
        if (mid * mid >= m) hi = mid;
        else lo = mid;
    }
    long long sqrtM = hi;

    // {a^0, a^1, a^2, ..., a^sqrt(m)} 
    map<long long, long long> apow;
    long long amari = a;
    for (long long r = 1; r < sqrtM; ++r) {
        if (!apow.count(amari)) apow[amari] = r;
        (amari *= a) %= m;
    }

    // check each A^p
    long long A = modpow(modinv(a, m), sqrtM, m);
    amari = b;
    for (long long q = 0; q < sqrtM; ++q) {
        if (amari == 1 && q > 0) return q * sqrtM;
        else if (apow.count(amari)) return q * sqrtM + apow[amari];
        (amari *= A) %= m;
    }

    // no solutions
    return -1;
}

long long ARC042D() {
    long long X, P, A, B;
    cin >> X >> P >> A >> B;
    long long r = modlog(X, 1, P);
    if (A == 0) return 1;
    if (B/r - (A-1)/r >= 1) return 1;

    A %= r, B %= r;
    if (B - A + 1 <= 1000000) {
        long long val = modpow(X, A, P);
        long long res = P;
        for (long long i = A; i <= B; ++i) {
            res = min(res, val);
            val = (val * X) % P;
        }
        return res;
    }
    else {
        for (long long b = 1; b < P; ++b) {
            long long exp = modlog(X, b, P);
            if (A <= exp && exp <= B) return b;
        }
    }
    return P;
}

int main() {
    cout << ARC042D() << endl;
}