けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 106 F - Figures (橙色, 800 点)

コンテスト本番、こっちをやればよかった...ところで解説が天才すぎる!

問題概要

 N 個の部品と、 N-1 個の接続用部品とがある。これらを用いてフィギュアを作ろうとしている。

 i 番目の部品には  d_{i} 個の穴がついている。接続用部品は、2 個の部品を選んで、それぞれの部品の穴を選択して挿し込むことができるようになっている。

フィギュアは  N 個の部品が連結でなければならない (木になる)。フィギュアの組み立て方が何通りあるか、998244353 で割ったあまりを求めよ。木として同一のものであっても、異なる穴組を接続したものは異なるものとみなす。

制約

  •  2 \le N \le 2 \times 10^{5}
  •  0 \le d_{i} \lt 998244353

考えたこと

まず、次の有名事実がある。


頂点数が  N であるような完全グラフの全域木であって、頂点  v の次数が  d_{v} ( d_{v} の総和が  2N-2) であるものの個数は

 \frac{(N-2)!}{(d_{1}-1)!(d_{2}-1)! \dots (d_{N}-1)!}

で与えられる。


このことの証明は、以下の記事に詳しく書いた。

drken1215.hatenablog.com

これを利用すると、以下を求めればよいことがわかる。

 \sum_{e_{i} \ge 1, e_{1} + \dots + e_{N} = 2N-2} (\frac{(N-2)!}{(e_{1}-1)! \dots (e_{N}-1)!} \times \frac{d_{1}!}{(d_{1} - e_{1})!} \dots \frac{d_{N}!}{(d_{N} - e_{N})!})
 = \sum_{e_{i} \ge 0, e_{1} + \dots + e_{N} = N-2} (\frac{(N-2)!}{e_{1}! \dots e_{N}!}\times \frac{d_{1}!}{(d_{1} - e_{1} - 1)!} \dots \frac{d_{N}!}{(d_{N} - e_{N} - 1)!})
 = (N-2)! d_{1} \dots d_{N} \sum_{e_{i} \ge 0, e_{1} + \dots + e_{N} = N-2} ({}_{d_{1}-1}{\rm C}_{e_{1}} \times \dots \times {}_{d_{N}-1}{\rm C}_{e_{N}})

ここで、 \sum_{e_{i} \ge 0, e_{1} + \dots + e_{N} = N-2} ({}_{d_{1}-1}{\rm C}_{e_{1}} \times \dots \times {}_{d_{N}-1}{\rm C}_{e_{N}}) は、

 f(x) = (1 + x)^{d_{1}-1} \dots (1 + x)^{d_{N}-1} = (1 + x)^{d_{1} + \dots + d_{N} - N}

 e_{1} + \dots + e_{N} ( = N-2) 次の係数になっている。よって、求める値は

 (N-2)! d_{1} \dots d_{N} \lbrack x^{N-2} \rbrack f(x)

と求められる。 f(x) は形式的冪級数の pow を用いることで高速に計算できる。

 

解法 (2):形式的冪級数を使わない

もう少し式変形する。 S = d_{1} + \dots + d_{N} とおく。

 (N-2)! d_{1} \dots d_{N} \lbrack x^{N-2} \rbrack f(x)
 = d_{1} \dots d_{N} \frac{(S-N)!}{(S-2N+2)!}
 = d_{1} \dots d_{N} (S-N)(S-N-1)\dots(S-2N+3)

と求められる。よって  O(N) で計算できる。

 

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int main() {
    int N;
    cin >> N;
    mint res = 1;
    long long S = 0;
    vector<long long> d(N);
    for (int i = 0; i < N; ++i) {
        cin >> d[i];
        res *= d[i];
        S += d[i];
    }
    for (int n = N; n <= N*2 - 3; ++n) res *= S - n;
    cout << res << endl;
}