けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 101 E - Ribbons on Tree (900 点)

すごく典型的な「二乗の木 DP」!!!!!
そして包除原理との組み合わせ。

問題へのリンク

問題概要

 N を偶数とする。
 N 頂点の木が与えられる。 N 頂点を  \frac{N}{2} 組の 2 つペアにする方法のうち、各ペアを結ぶパスをすべて考えたときに全辺が被覆されるようなものの個数を 1000000007 で割ったあまりを求めよ。

制約

  •  2 \le N \le 5000

考えたこと

めっちゃ面白そう!!!!!!!!!!
まず思ったのは、「すべての辺について、その辺が被覆される」という条件は扱いづらい。

一方、例えば「辺 e が被覆されないもの」は比較的数えやすそう。辺 e によって木は 2 つの部分にわかれ、それぞれのマッチングを数え上げて掛け算すればよさそう。

というわけでいかにも包除原理っぽい。 k 辺以上が被覆されない場合の数を求められれば包除原理が使える。そしてそれは木 DP で求められそう。

  • dp[ v ][ num ][ k ] := 頂点 v を根とした部分木において、k 個の辺を削除して、v を含む頂点が num 個あるような状態にする場合の数

とするとよさそう。ただしこのままでは計算時間的に厳しい。k のところは偶奇のみでよい。よって

  • dp[ v ][ num ][ 0 or 1 ] := 頂点 v を根とした部分木において、偶数 or 奇数個の辺を削除して、v を含む頂点が num 個であるような状態にする場合の数

とすれば木 DP が動く。一見 num について二重ループになって  O(N^{3}) に見えるけど実は  O(N^{2}) になってる定期。

#include <iostream>
#include <vector>
#include <cstring>
using namespace std;

// modint: mod 計算を int を扱うように扱える構造体
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) v += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

// 二項係数ライブラリ
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;
BiCoef<mint> bc;


int N;
vector<vector<int> > G;

// n 要素のマッチングの個数
vector<mint> pre;

// 木 DP
int size[5100];
mint dp[5100][5100][2];
mint sdp[5100][5100][2];

void rec(int v, int p = -1) {
    vector<int> chs;
    size[v] = 1;
    for (auto ch : G[v]) {
        if (ch == p) continue;
        rec(ch, v);
        chs.push_back(ch);
        size[v] += size[ch];
    }

    // サブ DP (初期化は i <= chs.size() で止めるのがポイント)
    for (int i = 0; i <= chs.size(); ++i)
        for (int j = 0; j <= size[v]; ++j)
            sdp[i][j][0] = sdp[i][j][1] = 0;
    sdp[0][0][0] = 1;
    int cur = 0;
    for (int i = 0; i < chs.size(); ++i) {
        int ch = chs[i];
        for (int num = 0; num <= cur; ++num) {
            for (int par = 0; par <= 1; ++par) {
                for (int num2 = 0; num2 <= size[ch]; ++num2) {
                    for (int par2 = 0; par2 <= 1; ++par2) {
                        // 切る
                        sdp[i+1][num][(par+par2+1)%2] += sdp[i][num][par] * dp[ch][num2][par2] * pre[num2];

                        // 切らない
                        sdp[i+1][num+num2][(par+par2)%2] += sdp[i][num][par] * dp[ch][num2][par2];
                    }
                }
            }
        }
        cur += size[ch];
    }

    // まとめ
    for (int num = 0; num <= size[v]; ++num) {
        for (int par = 0; par <= 1; ++par) {
            dp[v][num+1][par] = sdp[chs.size()][num][par];
        }
    }
}

int main() {
    bc.init(5100);

    // 入力
    cin >> N;
    G.assign(N, vector<int>());
    for (int i = 0; i < N-1; ++i) {
        int a, b; cin >> a >> b; --a, --b;
        G[a].push_back(b);
        G[b].push_back(a);
    }

    // 前処理
    pre.assign(N+1, 0);
    for (int n = 0; n <= N; n += 2) {
        pre[n] = bc.fact(n) * bc.finv(n/2) / modpow(mint(2), n/2);
    }

    // 木 DP
    memset(dp, 0, sizeof(dp));
    rec(0);

    // 集計
    mint res = 0;
    for (int num = 0; num <= N; ++num) {
        res += dp[0][num][0] * pre[num];
        res -= dp[0][num][1] * pre[num];
    }
    cout << res << endl;
}