けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 012 D - Don't worry. Be Together (試験管赤色)

経路数に帰着して頑張って二項係数計算するところは面白かった!
そのあとの任意 mod の二項係数処理は、今の AtCoder ではあまり出なさそうかな。

問題へのリンク

問題概要

 N 個の座標  (x_i, y_i) についての以下の値  f(i) を積算した値を  M で割ったあまりを求めよ。

原点から座標  (x_i, y_i) へと移動したい。1 ステップごとに上下左右いずれかの方向に 1 だけ移動できる。ちょうど  T ステップで  (x, y) へと到達する方法の場合の和を  f(i) とする。

制約

  •  0 \le |x|, |y|, T \le 10^{6}
  •  1 \le M \le 1000000007

考えたこと

  x, y \ge 0 として一般性を失わない。

もし  x + y = T であったら最短経路で行かないといけないので  C(x + y, x) 通りになる。 T <  x + y だったら不可能だし、 T \ge x + y のときも  T - (x + y) が奇数だったらダメ。ここで  r = \frac{T - (x + y)}{2} として、 r 回のうち  i 回が左右移動、 j 回が上下移動だとするとその場合の数は

  • 右移動:  x + i
  • 左移動:  i
  • 上移動:  y + j
  • 下移動:  j

であることから

  •  \frac{(x + y + 2r)!}{(x+i)!i!(y+j)!j!}

となる。これを  0 \le i \le r について総和をとることになる。二項係数の和はすごく頑張ると「経路数の和」みたいな形にすることで綺麗になったりするイメージがある。今回も頑張ってみる。

ちょっと式変形すると

 \frac{(x + y + 2r)!}{r!(x+y+r)!} \sum_{i=0}^{r} \frac{r!}{i!(r-i)!} \frac{(x+y+r)!}{(x+i)!(y+r-i)!}  = C(x + y+2r, r) \sum_{i=0}^{r} C(r, i) C(x+y+r, x+i)

になることがわかる。和になっているところはヴァンデルモンドのたたみ込みの式になっていて、

 C(x+y+2r, r)C(x+y+2r, x+r) = C(T, r) C(T, x+r)

と求められることがわかった。今の AtCoder ならここまでを問う出題になりそうだけど、本問題ではこれを任意 mod でやらなければならない...!!!

方針としては、この問題の場合にはいもす法でやるのが主流みたいだけど、僕は中国剰余定理で頑張ることにした。

  •  M = p_{1}^{e_{1}} p_{2}^{e_{2}} \dots p_{k}^{e_{k}} と素因数分解する
  •  C(n, r) {\rm mod} p^{k} を頑張って求めることを考える、それができれば中国剰余定理で復元できる
  •  C(n, r) = p^{e} q として、 e q {\rm mod} p^{k} を求める

という感じ。

#include <iostream>
#include <vector>
#include <map>
#include <cstring>
using namespace std;


const int MAX = 110000;

// p is prime, m = p^k
// a! = p^(ord[a]) * fac[a] (mod. m)
// (a!)^-1 = p^(-ord[a]) * finv[a] (mod. m);
long long ord[MAX], fac[MAX];
void prime_com_init(long long p, long long pm) {
    ord[0] = ord[1] = 0;
    fac[0] = fac[1] = 1;
    for (int i = 2; i < MAX; i++) {
        long long add = 0;
        long long ni = i;
        while (ni % p == 0) ++add, ni /= p;
        ord[i] = ord[i-1] + add;
        fac[i] = fac[ni-1] * ni % pm;
    }
}

inline long long mod(long long a, long long m) {
    return (a % m + m) % m;
}

long long modpow(long long a, long long n, long long mod) {
    long long res = 1;
    while (n > 0) {
        if (n & 1) res = res * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return res;
}

long long modinv(long long a, long long mod) {
    long long b = mod, u = 1, v = 0;
    while (b) {
        long long t = a/b;
        a -= t*b; swap(a, b);
        u -= t*v; swap(u, v);
    }
    u %= mod;
    if (u < 0) u += mod;
    return u;
}

// nCr mod. pm
long long COM(long long n, long long r, long long p, long long pm) {
    if (n < 0 || r < 0 || n < r) return 0;
    long long e = ord[n] - ord[r] - ord[n-r];
    long long res = fac[n] * modinv(fac[r] * fac[n-r] % pm, pm) % pm;
    res = res * modpow(p, e, pm) % pm;
    return res;
}

long long extGcd(long long a, long long b, long long &p, long long &q) {
    if (b == 0) { p = 1; q = 0; return a; }
    long long d = extGcd(b, a%b, q, p);
    q -= a/b * p;
    return d;
}

pair<long long, long long> chinese_rem(const vector<long long> &b, const vector<long long> &m) {
    long long r = 0, M = 1;
    for (int i = 0; i < (int)b.size(); ++i) {
        long long p, q;
        long long d = extGcd(M, m[i], p, q); // p is inv of M/d (mod. m[i]/d)
        if ((b[i] - r) % d != 0) return make_pair(0, -1);
        long long tmp = (b[i] - r) / d * p % (m[i]/d);
        r += M * tmp;
        M *= m[i]/d;
        r = mod(r, M);
    }
    return make_pair(mod(r, M), M);
}

vector<pair<long long, long long> > prime_factorize(long long n) {
    vector<pair<long long, long long> > res;
    for (long long p = 2; p * p <= n; ++p) {
        if (n % p != 0) continue;
        int num = 0;
        while (n % p == 0) { ++num; n /= p; }
        res.push_back(make_pair(p, num));
    }
    if (n != 1) res.push_back(make_pair(n, 1));
    return res;
}



int N, T, M;
vector<int> x, y, r;

long long solve() {
    cin >> N >> T >> M;
    x.resize(N); y.resize(N); r.resize(N);
    bool ok = true;
    for (int i = 0; i < N; ++i) {
        cin >> x[i] >> y[i];
        if (x[i] < 0) x[i] = -x[i];
        if (y[i] < 0) y[i] = -y[i];
        if (T < x[i] + y[i]) ok = false;
        if ((T - x[i] - y[i]) % 2 != 0) ok = false;
        r[i] = (T - x[i] - y[i]) / 2;
    }
    if (!ok) return 0;
    if (M == 1) return 0;

    vector<pair<long long, long long> > pf = prime_factorize(M);
    vector<long long> vb, vm;
    for (auto ps : pf) {
        long long p = ps.first, e = ps.second;
        long long pm = 1;
        for (int i = 0; i < e; ++i) pm *= p;
        prime_com_init(p, pm);
        long long b = 1;
        for (int i = 0; i < N; ++i) {
            b *= COM(T, r[i], p, pm) * COM(T, x[i] + r[i], p, pm) % pm;
            b %= pm;
        }
        vm.push_back(pm);
        vb.push_back(b);
    }
    auto res = chinese_rem(vb, vm);
    return res.first;
}

int main() {
    cout << solve() << endl;
}