けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 133 E - Virus Tree 2 (500 点)

木の走査って

  • 根の方から情報を配っていく
  • 子ノードたちの情報を引っ張ってくる (いわゆる木 DP)

という二つの方向性があって、状況に応じてうまいこと使い分けるとよいイメージがある。

問題へのリンク

問題概要

 N 頂点の木があたえられる。木の各頂点を  K 色に塗り分ける方法のうち、どの距離が 2 以下の二頂点も異なる色になるようにする方法が何通りあるか、1000000007 で割ったあまりを求めよ。

制約

  •  1 \le N, K \le 10^{5}

考えたこと

こういうの一目、木 DP なのだけど、ちょっと手を動かしてみると


なんか頂点を見ていくごとに、各頂点について「自由度」が決まって、その積が答えになる


みたいな雰囲気をかすかに感じた。なのでその直感を信じると、子ノードから頑張って情報を引っ張って木 DP する、というよりは単純に各ノードを見ていくだけで答えが求まりそう。

詰めてみる

根をとりあえず決める。

  • 根はどう塗ってもいいので  K 通り
  • 根の子頂点たちは、根と同じ色ではダメで、さらに互いにどの頂点も色が異なっていなければならないので、その個数を  c とすると、 {}_{K-1}{\rm P}_{c} 通りになる
  • それより深いところにある頂点 v についてその子頂点たちは、v とその親とは色が異なっていなければならず、さらに互いにどの頂点も色が異なっていなければならないので、その個数を  c とすると、 {}_{K-2}{\rm P}_{c} 通りになる

という感じで、根から下へ下がっていって掛け算していけば答えになる。

modint

以下の実装では、

を用いている。

#include <iostream>
#include <vector>
using namespace std;

// modint (1000000007 で割ったあまりを扱う構造体)
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) v += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

// 二項係数ライブラリ
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;
BiCoef<mint> bc;

int N, K;
vector<vector<int>> G;

void rec(int v, int p, mint &res, int depth) {
    int chs = 0;
    for (auto ch : G[v]) {
        if (ch == p) continue;
        ++chs;
        rec(ch, v, res, depth+1);
    }
    if (depth == 0) res *= bc.com(K-1, chs) * bc.fact(chs);
    else res *= bc.com(K-2, chs) * bc.fact(chs);
}

int main() {
    bc.init(110000);
    cin >> N >> K;
    G.assign(N, vector<int>());
    for (int i = 0; i < N-1; ++i) {
        int a, b; cin >> a >> b; --a, --b;
        G[a].push_back(b);
        G[b].push_back(a);
    }
    mint res = 1;
    rec(0, -1, res, 0);
}