けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

diverta 2019 F - Edge Ordering (銅色, 1200 点)

こういうの素早く解けるようになりたいね。
いわゆる「トポロジカルソート順の数え上げ」という高難易度でたまに見るパターンの問題。

問題へのリンク

問題概要

 N 頂点  M 辺の無向グラフが与えられる。ここで、辺  1, 2, \dots, N-1 が全域木を形成していることが保証されている。

各辺に  1, 2, \dots, M の重みをつける  M! 通りの方法のうち、 1, 2, \dots, N-1 の辺たちが最小全域木をなすものが何通りあるか、109 + 7 で割った余りを求めよ。

制約

  •  2 \le N \le 20

考えたこと: トポロジカルソート数え上げへの帰着

まずは最小全域木の特徴づけが

  • MST をなす辺以外の各辺  e に対して、その辺と MST とでサイクルが 1 つ形成されるが、その中で  e の重みが最大である

で与えられることを思い出す (ここ参照)。そうすると、この問題は  1, 2, \dots, M に対して重み  w(1), w(2), \dots, w(M) を割り当てる方法のうち、

  •  1 \le i \le N-1,  N \le j \le M となる  i, j について、 w(i) \lt w(j) を満たす必要がある

というタイプの制約がひたすら並ぶようなトポロジカルソート順を走査する問題になる。なんか、こういう風に「制約を書き出すとトポロジカルソート走査問題に帰着するタイプの問題」は、下のような問題でも見た。

atcoder.jp

トポロジカルソート走査・数え上げの指針

一般には #P-complete な問題だが、今回数え上げたい有向グラフは

  • 二部グラフであって、左ノードから右ノードへの順序がある形状
  • 左ノード数が  N-1

という特徴がある。そして左ノードの順序を固定してみると、例えば i < j であって、i -> a、j -> a という順序がある部分について、i -> a は取り除いてよくて、下図 (editorial 引用) のような形の有向グラフのトポロジカルソートを走査することになる。

これは、比較的容易にできる。ポイントとしては右から順に並べる。この順序を間違うと、よくわからなくなる。 1, 2, \dots, N-1 にくっついているノードの個数をそれぞれ  a_{1}, a_{2}, \dots, a_{N-1} とする (それぞれ 1, 2, ..., N-1 も含むようにする) と、 P を順列として

  • まず、 N-1 a_{N-1} - 1 個とを並べる  P(a_{N-1} - 1, a_{N-1} - 1) 通り
  • 次に、 N-2 N-1 の左において、 a_{N-2} - 1 個を挿入することを考えると、 P(a_{N-2} + a_{N-1} - 1, a_{N-2} - 1) 通り
  • 次に、 N-3 N-2 の左において、 a_{N-3} 個を挿入することを考えると、 P(a_{N-3} + a_{N-2} + a_{N-1} - 1, a_{n-3} - 1) 通り
  • ...

これを掛け算することになる。さらに本当に求めるものは「個数」ではなく「重みの総和」であるが、これは頑張る。まず


白いボール  w 個と、黒いボール  b 個あるときに、その並びが  c 個あって、黒いボールの index の総和が  s であったとする。ここから、各箇所にボールを挿入したときの  s の変化は、

  • 挿入するのが黒のとき、 c c のままで、 s s + (b + 1)c になる
  • 挿入するのが白のとき、 c (b+w+1)c になって、 s (b+w+2)s になる

ということに注目する。 s の変化がとてもわかりづらいけど、一度知ってしまえば良さそうな感じ。理由としては

  • 各黒ボールが右にずれる分を除外したときに、 (b + w + 1)s
  • index が  i となっている各黒ボールにつき、 i が右にずれるような  i の左側の挿入箇所は  i 通りあって、それらそれぞれについて  i が加算されるイメージ、したがって全体としては  s が加算される

ということで、合計で  (b + w + 2)s になる。

というわけでまとめると、


すでにある黒ボール  b 個と白ボール  w 個の並びが、 c 通りあって、黒ボールの index の合計が  s であるときに、黒ボール 1 個と、白ボール  a-1 個とを加えたとき、

  •  b b+1
  •  w w + a-1
  •  c (b+w+a-1)(b+w+a-2)...(b+w+1)c
  •  s (b+w+a)(b+w+a-1)...(b+w+2)s + (b+1)(b+w+a-1)(b+w+a-2)...(b+w+1)c

なる


bitDP へ

ここまでは左ノード 1, 2, ..., N-1 の順序を固定したときの話だが、順序をいじるなら bitDP でできそう。

  • cdp[ bit ] := bit で表されたノードと、そこから飛び出ているノードを並べる場合の数
  • sdp[ bit ] := bit で表されたノードと、そこから飛び出ているノードたちの並びすべてについての bit で表されたノードの index 和の総和

としておく。このとき、更新は、

  • con[ bit ] := bit に含まれるのが何個か
  • 1, 2, ..., N-1 のうち、bit で表される頂点たちから繋がっている頂点数を cnt[ bit ] (bit の個数も含む)

としておいて、

  • bit に含まれない要素 i を足して nbit = bit | (1<<i) として
  • cdp[nbit] = (cnt[nbit])(cnt[nbit ] - 1)...(cnt[bit] + 2) × cdp[bit]
  • sdp[nbit] = (cnt[nbit] + 1)(cnt[nbit])...(cnt[bit] + 3) × sdp[bit] + (con[bit] + 1)(cnt[nbit])(cnt[nbit ] - 1)...(cnt[bit] + 2) × cdp[bit]

で求められる。このあたりの計算量は  O(N2^{N}) でできる。

cnt[ bit ] の求め方

残るタスクは、cnt[ bit ] を求めること。すなわち bit のいずれかの要素から出ている頂点集合のサイズを求めること。ここを愚直にやっては  O(MN2^{N}) かかっておそらく間に合わない。

これは、

  • f( S ) := 「辺 i = N, N+1, ..., M に対して、MST に対して辺 i を付け加えてできるサイクルに含まれる辺 i 以外の辺 (1 以上 N-1 以下) の集合が S である」ような i の個数

としてあげて、

  • g( S ) =  \sum_{T \subseteq S} f( T )
  • cnt[ S ] = M - (N-1) - g( S の補集合 ) + |S|

となることがわかるので、これは高速ゼータ変換によって求めることができる。

コード

以上をまとめる。計算量は、

  • 高速ゼータ変換:  O(N2^{N})
  • bitDP:  O(N2^{N})
#include <iostream>
#include <vector>
#include <bitset>
using namespace std;

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) v += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

// 二項係数ライブラリ
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

const int MAX = 201010;
const int MOD = 1000000007;
using mint = Fp<MOD>;

int N, M;
using pint = pair<int,int>;
using Graph = vector<vector<pint> >;
Graph G;

int dfs(int v, int p, int goal) {
    if (v == goal) return 0;
    int res = -1;
    for (auto e : G[v]) {
        if (e.first == p) continue;
        int tmp = dfs(e.first, v, goal);
        if (tmp != -1) {
            if (res == -1) res = tmp | (1<<e.second);
            else res |= tmp | (1<<e.second);
        }
    }
    return res;
}

int main() {
    BiCoef<mint> bc(MAX);
    
    // 入力と、MST 以外の辺が MST と作るサイクルたち
    cin >> N >> M;
    G.assign(N, vector<pint>());
    for (int i = 0; i < N-1; ++i) {
        int a, b; cin >> a >> b; --a, --b;
        G[a].push_back({b, i});
        G[b].push_back({a, i});
    }
    vector<int> cmp(1<<(N-1), 0);
    for (int i = N-1; i < M; ++i) {
        int a, b; cin >> a >> b; --a, --b;
        int S = dfs(a, -1, b);
        cmp[S]++;
    }

    // 高速ゼータ変換
    for (int i = 0; i < N-1; ++i)
        for (int bit = 0; bit < 1<<(N-1); ++bit)
            if (bit & (1<<i))
                cmp[bit] += cmp[bit ^ (1<<i)];
    vector<int> cnt(1<<(N-1), 0);
    for (int bit = 0; bit < 1<<(N-1); ++bit) {
        cnt[bit] = M-(N-1) - cmp[(1<<(N-1))-1 - bit] + __builtin_popcount(bit);
    }
    
    // DP
    vector<mint> cdp(1<<(N-1), 0), sdp(1<<(N-1), 0);
    cdp[0] = 1;
    for (int bit = 0; bit < 1<<(N-1); ++bit) {
        long long con = __builtin_popcount(bit);
        for (int i = 0; i < N-1; ++i) {
            if (bit & (1<<i)) continue;
            int nbit = bit | (1<<i);
            cdp[nbit] += bc.fact(cnt[nbit]-1) * bc.finv(cnt[bit]) * cdp[bit];
            sdp[nbit] += bc.fact(cnt[nbit]) * bc.finv(cnt[bit]+1) * sdp[bit]
                + bc.fact(cnt[nbit]-1) * bc.finv(cnt[bit]) * cdp[bit] * (con + 1);
        }
    }
    cout << sdp[(1<<(N-1)) - 1] << endl;
}