けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Codeforces #460 (Div. 2) E. Congruence Equation (R2100)

中国剰余定理のことをあれこれ調べていたら勢いで解いたん

問題へのリンク

問題概要

整数  a, b, p, x が与えられる。

 Na^{N} ≡ b (mod.  p)

が成立するような  1 以上  x 以下の整数  N を数え上げよ。

制約

解法

Fermat の小定理から以下のことが言える:

  •  N は mod.  p において周期  p である
  •  a^{N} は mod.  p において周期  p-1 である

したがって、 p p-1 は互いに素だから、 Na^{N} mod.  p は全体として周期  p(p-1) となることがわかる。

ここで、 a^{N} の方を固定して集計することを考える。すなわち、 N ≡ k (mod.  p-1) の場合について数えて条件を満たす  N の個数を数え上げて、それを  k = 0, 1, \dots, p-2 について合算する。

 Na^{k} ≡ b (mod.  p)
 N ≡ a^{-k}b (mod.  p)

となることから、すなわち

  •  N ≡ k (mod.  p-1)
  •  N ≡ a^{-k}b (mod.  p)

を満たす  N を数え上げることになる。中国剰余定理により、

 N ≡ r (mod.  p(p-1))

を満たすような  r が求められるので、それを用いて  x 以下の整数を数え上げる。

#include <iostream>
#include <vector>
#include <cmath>
using namespace std;

inline long long mod(long long a, long long m) {
    return (a % m + m) % m;
}

long long inv(long long a, long long m) {
    long long b = m, u = 1, v = 0;
    while (b) {
        long long t = a/b;
        a -= t*b; swap(a, b);
        u -= t*v; swap(u, v);
    }
    return mod(u, m);
}

long long gcd(long long a, long long b) {
    if (b == 0) return a;
    else return gcd(b, a % b);
}

pair<long long, long long> ChineseRem(long long b1, long long m1, long long b2, long long m2) {   
        long long x = 0, m = 1;
        long long a = m, b = b1 - x, d = gcd(m1, a);
        if (b % d != 0) return make_pair(0, -1);
        long long t = mod(b / d * inv(a / d, m1 / d), m1 / d);
        x += m * t;
        m *= m1 / d;
        a = m, b = b2 - x, d = gcd(m2, a);
        if (b % d != 0) return make_pair(0, -1);
        t = mod(b / d * inv(a / d, m2 / d), m2 / d);
        x += m * t;
        m *= m2 / d;
        return make_pair(x % m, m);
}

long long a, b, p, x;

int main() {
  cin >> a >> b >> p >> x;
  long long res = 0;
  long long pow = 1;
  for (long long k = 0; k < p-1; ++k) {
    // ≡ k(mod. p-1), ≡ a^{-k}b (mod. p)
    long long b2 = inv(pow, p) * b % p;
    pair<long long, long long> c = ChineseRem(k, p-1, b2, p);
    long long tmp = (x + (c.second - c.first) % c.second) / c.second;
    res += tmp;
    pow = pow * a % p;
  }
  cout << res << endl;
}