けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AOJ 2378 SolveMe (JAG冬合宿 2010 day3-J)

伝説の良難問。
現在でこそ AC 数 30 人で解説記事も豊富にあるが、当時は AC 数 3 人という状況で解説も無い中で、必死に 1 週間かけて通した想い出の問題。

問題へのリンク

問題概要

正の整数  N 0 以上の整数  X, Y, Z が与えられる。
{ 0, 1, \dots, N-1} から { 0, 1, \dots, N-1} への写像  f, g の組であって、以下の条件を満たすものが何通りあるか、1000000007 で割ったあまりを求めよ。

  • id を恒等写像として、 f^{Z} g f^{Y} g f^{X} = id

制約

  •  1 \le N \le 1000
  •  0 \le X, Y, Z \le 10^{18}

考えたこと

順列に関する問題!!!
順列を写像として見て扱うのは、以下の問題でもお馴染み。

drken1215.hatenablog.com

まず、 f^{Z} g f^{Y} g f^{X}全単射であることから、 g は少なくとも全単射でなければならず、 f については

  •  X = Y = Z = 0 のときは、 f は任意の写像でよい ( N^{N} 通り)
  • それ以外のときは、 f全単射でなければならない

 X = Y = Z = 0 の場合がコーナーケースとなっている。なお、コーナーケースの場合は  f の個数がイレギュラーであることを除けば、以下の議論はそう大差ない。

式変形

 f, g がともに全単射である場合 (順列となっている場合) について、式変形してみる。全単射なので写像が存在する。よって  f^{-1} g^{-1} を考えてよい。左から  f^{Y - Z} を、右から  f^{-X} をかけることで、以下のようになる。

 f^{Z} g f^{Y} g f^{X} = id
 f^{Y} g f^{Y} g = f^{-X + Y - Z}

ここで改めて  h = f^{Y} g とおくと、 (f, g) の組と、 (f, h) の組とは一対一に対応するので、結局

  •  h^{2} = f^{-X + Y - Z}

を満たすような順列  f, h の組を求めればよいことになる。なお、 -X + Y - Z が負である場合には、改めて  k = h^{-1} として、

  •  k^{2} = f^{X - Y + Z}

となるので、 -X + Y - Z は 0 以上であるとしてよい (絶対値をとった値で考えてよい)。以上をまとめると、 t = |-X + Y - Z| として、記号を置き直して、

  •  g^{2} = f^{t}

を満たすような順列  (f, g) の組を求める問題ということになる。

数え上げの方針

ここまでくると簡単に思える方もいるらしい...僕はここまではむしろ一瞬だったのに、ここから詰め切るのに 1 週間かかった ( X = Y = Z = 0 というコーナーケースには気づいてなかった)。

方針としては、各順列  h に対して

  •  f^{t} = h
  •  g^{2} = h

を満たすような  f g をそれぞれ数え上げて、その積をとる。それを  N! 通りの  h に対して総和をとる、という方針をとることにする。 f の方が一般的なので  f について考える。

順列 f, h を巡回群の直積に分解して考える

整数問題では、整数を素因数分解できることから、まず素数の場合を解いて、素数べきの場合を解いて、最後に全体をまとめるようなアプローチをよくする。順列に対しても、巡回群の直積に分解できることから、まず巡回群の場合を解くのは自然だと思われる。

一旦仮に、 f を位数  e巡回群であるとしてみよう ( f, h の定義域のサイズは一時的に  N ではなく  e であるとする)。

このとき、たとえば  t e とが互いに素であったならば、 f h とは一対一対応することがわかる (より正確にいえば、 f に対して  f^{t} を返す写像は、「位数  e巡回群の集合」から「位数  e巡回群の集合」への全単射である。一般の場合には、

  •  h = f^{t} の位数  e' は、 g = {\rm GCD}(t, d) として、 e' = \frac{e}{g} となる
  •  t' = \frac{t}{g} と置くと  e' t'互いに素であり、 f を位数  e巡回群として考えている限りは、 f^{g} = h を満たすような  (f, h) の組の個数を数えても等しい

ということがわかる。よって代わりにこれを数える。 f が位数  e = e'g巡回群であるとき、 h = f^{g} は「位数  e'巡回群 g 個の直積」となる。さて、まず、そのような  h を一つ固定したときに、それに対応する  f が何個あるのかを考えよう。

  •  f は位数  e = e'g巡回群であるが、 f^{g} を計算すると、 f の各元は  g 個先の元へと移ることになる (つまり位数は  e' となる)
  •  h を固定したとき、 f の各元が  g 個進んでどの元になるのかは一意に決まるので、よって  f において各元は  e' 個のメンバをもつ  g 個の「数珠」に分かれることになる
  • そこから  f を構成することは、各数珠の順序を固定した上で、数珠を繋ぎ直すことに相当する
    •  g 個の数珠の順序を決める:  (g-1)! 通り
    •  g 個の数珠それぞれについて、中身の offset を決める:  e'^{g-1} 通り

ということで、「位数  e'巡回群 g 個の直積」であるような  h に対して、「位数  e = e'g巡回群」であるような  f の個数は、 1 :  (g-1)! e'^{g-1} に対応することがわかる。

DP へ

以上の考察から、 p = t, q = 2 として、 h = f^{p} = q^{q} を満たす  f, g (, h) の組の個数を求めるために、以下の DP を回すことを考えよう:

  • dp[  n ][  e ] :=  f, g, h 1, 2, \dots, n の順列としたときに、 h の位数が  e 以下であるようなものの個数

そして、次のような更新式が立つ。  C(n+ke, n) の部分は挿入 DP の考え方を使っていて、 \frac{(ke)!}{k!(e!)^{k}} の部分は「 ke 個のものを  k グループに分けるグループ分けの個数」である。

  • dp[  e ] [  n + ke ] += dp[  e-1 ][  n ] ×  C(n+ke, n) ×  \frac{(ke)!}{k!(e!)^{k}} ×  ((e-1)!)^{k} × ( f^{p} = h となるような  f の個数) × ( g^{q} = h となるような  g の個数)

( h を、位数  e巡回群 k 個直積の一つとする)

最後の DP

最後に、 f^{p} = h となるような  f の個数を考える。各  e に対して、 h は位数  e巡回群 k 個の直積である。さて、 f を構成する巡回群の位数を  d とするとき、

  •  g = {\rm GCD}(d, p) として、 e = \frac{d}{g} となる
  • よって、 e の倍数  d のうち、 e = \frac{d}{{\rm GCD}(d, p)} となるものを考えれば良い

そのような  d d_{0}, d_{1}, \dots, d_{m-1} とする。このとき、

  • dp2[  i ][  k ] :=  d として最初の  i 種類のみを考えた場合についての、 f^{p} h のうちの最初の「位数が  e巡回群 k 個の直積」となるような  f の個数

とする。以下のような DP が立つ。各  d = d_{i} に対して  g = \frac{d}{e} として、

  • dp2[  i + 1 ][  k + lg ] += dp2[  i ][  k ] ×  C(k + lg, k) ×  \frac{(lg)!}{l!(g!)^{l}} ×  ((g-1)! e^{g-1})^{l}

と更新できる。 C(k + lg, k) の部分は挿入 DP の思想に従って  h に新たに  lg 個の巡回群 (位数  e) を付け加えることを表す (ここを  C(ke + ld, ke) としてしまい、バグに苦しんだ)。 \frac{(lg)!}{l!(g!)^{l}} の部分は、 f^{p} をなす  g \times l 個の巡回群 l 個のグループに分ける方法の個数を表す。

コード

必死のデバッグの跡

https://ideone.com/7ApGye

#include <iostream>
#include <vector>
#include <algorithm>
#include <map>
using namespace std;

long long GCD(long long x, long long y) {
    if (y == 0) return x;
    else return GCD(y, x % y);
}

// modint: mod 計算を int を扱うように扱える構造体
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

// 二項係数ライブラリ
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;
BiCoef<mint> bc;

// e*k 人を、e 人の k グループに分ける方法の数
mint grouping(int e, int k) {
    return bc.fact(k*e) * bc.finv(k) * modpow(bc.finv(e), k);
}

// res[k] = h を位数 e の巡回群の k 個の直積からなる 1 つの順列としたときに
// f^p = h となる f が何個あるか
vector<mint> sub(int N, long long p, long long e) {
    int K = N / e;

    // f を構成する巡回群の位数 d として、e = d / GCD(d, p) となる d のみを考える
    vector<long long> D;
    for (long long d = e; d <= N; d += e) {
        if (e * GCD(d, p) == d) D.push_back(d);
    }
    
    // DP
    vector<vector<mint>> dp(D.size()+1, vector<mint>(K+1, 0));
    dp[0][0] = 1;
    for (int i = 0; i < D.size(); ++i) {
        int d = D[i];
        int g = d / e;
        mint mul = bc.fact(g - 1) * modpow(mint(e), g - 1);
        for (int k = 0; k <= K; ++k) {
            mint fac = 1;
            for (int l = 0; k + l * g <= K; ++l) {
                dp[i + 1][k + l * g] += dp[i][k] *
                    bc.com(k + l * g, k) * grouping(g, l) * fac;
                fac *= mul;
            }
        }
    }
    return dp[D.size()];
}

mint solve(int N, long long X, long long Y, long long Z) {
    if (X + Y + Z == 0) {
        auto gdp = sub(N, 2, 1);
        return gdp[N] * modpow(mint(N), N);
    }
    long long t = X - Y + Z;
    if (t < 0) t = -t;
    vector<vector<mint>> dp(N+1, vector<mint>(N+1, 0));
    dp[0][0] = 1;
    for (int e = 1; e <= N; ++e) {
        const vector<mint>& fdp = sub(N, t, e);
        const vector<mint>& gdp = sub(N, 2, e);
        for (int n = 0; n <= N; ++n) {
            mint fac = 1;
            for (int k = 0; n + k * e <= N; ++k) {
                dp[e][n + k * e] += dp[e - 1][n] *
                    bc.com(n + k * e, n) * grouping(e, k) * fac * fdp[k] * gdp[k];
                fac *= bc.fact(e - 1);
            }
        }
    }
    return dp[N][N];
}

int main() {
    bc.init(5100);
    int N;
    long long X, Y, Z;
    while (cin >> N >> X >> Y >> Z) cout << solve(N, X, Y, Z) << endl;
}