けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

ACL Beginner Contest F - Heights and Pairs (橙色, 600 点)

こういうのは包除原理しかない!

問題へのリンク

問題概要

 2N 人の人がいる。 人  i の身長は  h_{i} である。以下の条件を満たすように、 N 組のペアを作る方法は何通りあるか、998,244,353 で割ったあまりを求めよ。

  • どの人もちょうど一つのペアに含まれる。
  • どのペアも、そのペアに属する二人の人の身長が異なる。

制約

  •  1 \le N \le 50000

考えたこと

「どのペアも、身長が異なる」という条件が無限に扱いにくい。こういうのは包除原理するしかない。「身長が異なる」という条件は余事象をとると「身長が一致する」となる。そうなればまだ扱いやすいと感じられるようになる。さて、包除原理するならば、次の値を求めることになる。

(0 組の同身長ペアを含む場合の数)
- (1 組の同身長ペアを指定して、それを含む場合の数)
+ (2 組の同身長ペアを指定して、それを含む場合の数)
- (3 組の同身長ペアを指定して、それを含む場合の数)
+ ...

まず、 2N 人を身長ごとに分類する。そして、各身長に対して以下のような DP をしたくなる。

  •  {\rm dp} \lbrack v \rbrack := これまで考えた身長グループにおいて、同身長ペアを v 組生やした場合の場合の数
  •  {\rm ndp} \lbrack v \rbrack := それに新たに身長グループを追加したとき、同身長ペアが v 組になるような場合の数

新たな身長グループの人数が  n 人であるとき、遷移は次のようになる。各  i = 0, 1, \dots, n/2 に対して、

 {\rm ndp} \lbrack k + i \rbrack +=  {\rm dp} \lbrack k \rbrack \times \frac{{}_{n}{\rm P}_{2i}}{2^{i}}

これでとりあえず  O(N^{2}) な解法が生まれた。

FFT へ

上の DP 更新式を見ると、 {\rm dp} \lbrack k \rbrack からの遷移が添字  k に依存しないことがわかる。このような場合、DP 遷移を畳み込み積として表せる。具体的には、以下の畳み込みとみなせる。

  •  ({\rm dp} \lbrack 0 \rbrack, {\rm dp} \lbrack 1 \rbrack, {\rm dp} \lbrack 2 \rbrack, \dots)
  •  (\frac{{}_{n}{\rm P}_{0}}{2^{0}}, \frac{{}_{n}{\rm P}_{2}}{2^{1}}, \frac{{}_{n}{\rm P}_{4}}{2^{2}}, \dots)

よって FFT で扱うことができる。つまり各身長グループの人数を  n_{0}, \dots, n_{K-1} としたとき、

  •  (\frac{{}_{n_{0}}{\rm P}_{0}}{2^{0}}, \frac{{}_{n_{0}}{\rm P}_{2}}{2^{1}}, \frac{{}_{n_{0}}{\rm P}_{4}}{2^{2}}, \dots)
  • ...
  •  (\frac{{}_{n_{K-1}}{\rm P}_{0}}{2^{0}}, \frac{{}_{n_{K-1}}{\rm P}_{2}}{2^{1}}, \frac{{}_{n_{K-1}}{\rm P}_{4}}{2^{2}}, \dots)

の畳み込み積を求めていけばよい。しかしなんの工夫もしないと、 O(N^{2} \log N) の計算量となってかえって悪化してしまうように思える。

計算量削減

ほんの少しの工夫で一気に計算量が削減できる。 a 次多項式と  b 次多項式の積をとると  a + b 次多項式となることに注目する。この積をとるときの計算量は  c = \max(a, b) として  O(c \log c) となる。

さて、一般に  m 個の多項式の積をとるとき、以下のようにすると計算量が最小となることは明らかである (ハフマン符号と同じ Greedy)。


「次数の低い 2 個の多項式をとりだして、それらを掛け合わせる」という操作を、多項式が 1 個にマージされるまでやる


そして、なんとこのハフマン符号的な方法によって、計算量が  O(N (\log N)^{2}) で抑えられるのだ。それを簡単に評価してみよう。

上の方法より良くはならない方法として、次のような分割統治的な方法がある。もしこれが  O(N (\log N)^{2}) で抑えられるならば、上記の最良の方法も当然  O(N (\log N)^{2}) で抑えられることになる。


「今ある多項式を隣接する 2 個ずつをマージしていく」という操作を、多項式が 1 個にマージされるまでやる。
(1 回のステップでは、多項式が  m 個あるときは  m/2 個に半減する)


実は、このような分割統治的な方法でも  O(N (\log N)^{2}) の計算量となる。なぜならば、どの多項式もそれが積計算に絡む回数が  O(\log N) 回でおさえられるからだ。つまり、多項式のマージ過程を考えると下図 (解説放送より) のような強平衡二分木となって、その高さは  O(\log N) となる。

そしてどの階層 (深さ) をみても、それらの多項式の次数の総和は  N となっているので、

  • 「隣接する 2 個ずつマージする」という 1 ステップの計算量は  O(N \log N)
  • そのステップが  O(\log N) 回あるので、全体の計算量は  O(N (\log N)^{2})

となる。

以上から、 O(N (\log N)^{2}) で数え上げられることがわかった。

コード

ここでは自前の NTT ライブラリを用いた。

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    constexpr bool operator < (const Fp& r) const noexcept {
        return this->val < r.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

// NTT
namespace NTT {
    long long modpow(long long a, long long n, int mod) {
        long long res = 1;
        while (n > 0) {
            if (n & 1) res = res * a % mod;
            a = a * a % mod;
            n >>= 1;
        }
        return res;
    }

    long long modinv(long long a, int mod) {
        long long b = mod, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        u %= mod;
        if (u < 0) u += mod;
        return u;
    }

    int calc_primitive_root(int mod) {
        if (mod == 998244353) return 3;
        int divs[20] = {};
        divs[0] = 2;
        int cnt = 1;
        long long x = (mod - 1) / 2;
        while (x % 2 == 0) x /= 2;
        for (long long i = 3; i * i <= x; i += 2) {
            if (x % i == 0) {
                divs[cnt++] = i;
                while (x % i == 0) x /= i;
            }
        }
        if (x > 1) divs[cnt++] = x;
        for (int g = 2;; g++) {
            bool ok = true;
            for (int i = 0; i < cnt; i++) {
                if (modpow(g, (mod - 1) / divs[i], mod) == 1) {
                    ok = false;
                    break;
                }
            }
            if (ok) return g;
        }
    }

    int get_fft_size(int N, int M) {
        int size_a = 1, size_b = 1;
        while (size_a < N) size_a <<= 1;
        while (size_b < M) size_b <<= 1;
        return max(size_a, size_b) << 1;
    }

    // number-theoretic transform
    template<class mint> void trans(vector<mint> &v, bool inv = false) {
        if (v.empty()) return;
        int N = (int)v.size();
        int MOD = v[0].getmod();
        int PR = calc_primitive_root(MOD);
        static bool first = true;
        static vector<long long> vbw(30), vibw(30);
        if (first) {
            first = false;
            for (int k = 0; k < 30; ++k) {
                vbw[k] = modpow(PR, (MOD - 1) >> (k + 1), MOD);
                vibw[k] = modinv(vbw[k], MOD);
            }
        }
        for (int i = 0, j = 1; j < N - 1; j++) {
            for (int k = N >> 1; k > (i ^= k); k >>= 1);
            if (i > j) swap(v[i], v[j]);
        }
        for (int k = 0, t = 2; t <= N; ++k, t <<= 1) {
            long long bw = vbw[k];
            if (inv) bw = vibw[k];
            for (int i = 0; i < N; i += t) {
                mint w = 1;
                for (int j = 0; j < t/2; ++j) {
                    int j1 = i + j, j2 = i + j + t/2;
                    mint c1 = v[j1], c2 = v[j2] * w;
                    v[j1] = c1 + c2;
                    v[j2] = c1 - c2;
                    w *= bw;
                }
            }
        }
        if (inv) {
            long long invN = modinv(N, MOD);
            for (int i = 0; i < N; ++i) v[i] = v[i] * invN;
        }
    }

    // small case (T = mint, long long)
    template<class T> vector<T> naive_mul 
    (const vector<T> &A, const vector<T> &B) {
        if (A.empty() || B.empty()) return {};
        int N = (int)A.size(), M = (int)B.size();
        vector<T> res(N + M - 1);
        for (int i = 0; i < N; ++i)
            for (int j = 0; j < M; ++j)
                res[i + j] += A[i] * B[j];
        return res;
    }

    // mul
    template<class mint> vector<mint> mul
    (const vector<mint> &A, const vector<mint> &B) {
        if (A.empty() || B.empty()) return {};
        int N = (int)A.size(), M = (int)B.size();
        if (min(N, M) < 30) return naive_mul(A, B);
        int size_fft = get_fft_size(N, M);

        vector<mint> a(size_fft), b(size_fft), c(size_fft);
        for (int i = 0; i < N; ++i) a[i] = A[i];
        for (int i = 0; i < M; ++i) b[i] = B[i];
        trans(a), trans(b);
        vector<mint> res(size_fft);
        for (int i = 0; i < size_fft; ++i) res[i] = a[i] * b[i];
        trans(res, true);
        res.resize(N + M - 1);
        return res;
    }
};

// Binomial coefficient
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int main() {
    int N;
    cin >> N;
    map<int,long long> ma;
    for (int i = 0; i < N*2; ++i) {
        int h;
        cin >> h;
        ma[h]++;
    }
    BiCoef<mint> bc(N*2+1);

    priority_queue<pair<int,vector<mint>>, vector<pair<int,vector<mint>>>, greater<pair<int,vector<mint>>>> que;
    for (auto it : ma) {
        int n = it.second;
        vector<mint> pol(n/2+1, 1);
        for (int i = 0; i <= n/2; ++i) {
            pol[i] = bc.fact(n) * bc.finv(n - i*2) * bc.finv(i) / modpow(mint(2), i);
        }
        que.push({pol.size(), pol});
    }
    while (que.size() >= 2) {
        auto f = que.top().second; que.pop();
        auto g = que.top().second; que.pop();
        auto h = NTT::mul(f, g);
        que.push({h.size(), h});
    }
    auto func = que.top().second;

    mint res = 0;
    for (int k = 0; k < func.size(); ++k) {
        mint fac = bc.fact(N*2 - k*2) * bc.finv(N-k) / modpow(mint(2), N-k);
        if (k % 2 == 0) res += func[k] * fac;
        else res -= func[k] * fac;
    }
    cout << res << endl;
}