けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 213 G - Connectivity 2 (橙色, 600 点)

面白かった!! より一般化した問題 (グラフの頂点集合の任意の部分集合に対して、それらを連結にする辺の選び方の数え上げ) を考えた方が考えやすいね。

問題概要

頂点数  N、辺数  M の単純無向グラフ  G が与えられます。 G の辺集合の部分集合 ( 2^{M} 通りある) たちを考える。

各頂点  k = 1, 2, \dots, N-1 に対して、頂点  0 と頂点  k とが連結であるような部分集合が何個あるかを 998244353 で割ったあまりを求めよ。

制約

  •  1 \le N \le 17

連結グラフの数え上げ問題へ

頂点数の小さい無向グラフが与えられたときに、全頂点が連結であるような部分グラフ (辺集合の部分集合) の個数を数え上げる問題は割と典型な気がする。そしてそれを求める過程では、頂点集合のすべての部分集合  S に対しても、 S が連結であるような部分グラフ ( S 内部の辺のみを考える) の個数をすべて求めることになるのだ。

つまり、次の関数  f(S) が求められることになる (後述)。


 f(S) = 頂点集合  S ( \subset V) について、その内部にある辺たちの部分集合のうち、 S 全体が連結であるようなものの個数


そして今回の問題は  S として、とくに  S = (0, k) をとれば、よさそうだ。ただし注意点として、 f(S) はあくまで頂点集合  S 内部の辺のみを考えている。今回求めたいのは、もとのグラフ全体の辺たちの部分集合のうち、 S 全体を連結にするようなものの個数なのだ。

そこで、次のような場合分けをしよう。 S を含む連結成分で場合分けするのだ。

  •  S を含む連結成分をなる頂点の部分集合が  T であるとき、
  •  V - T に含まれる辺の本数を  e(V - T) とすると
  • そのような場合の数は  f(T) \times 2^{e(V - T)} 通り
    •  T V - T との間に辺を引くことは考えなくてよい

というふうに考えられる。これを各  T に対して総和をとればよいということになる。以上より、もとの問題は  f(T) を求める問題へと帰着された。

なお、上のことを実現するためには、各部分集合  S に対して  e(S) を計算する必要もある。愚直にやると  O(M 2^{N}) になる。今回はこれでも十分間に合うが、高速ゼータ変換を用いると  O(N 2^{N}) になる。

高速ゼータ変換

具体的には、頂点集合の部分集合  S に対して

  •  h(S) =  S のサイズが 2 であり、それらがグラフ  G のある辺の両端点になっているとき 1、それ以外のとき 0

という関数  h を定義すると

 e(S) = \sum_{T \subset S} h(T)

と表せることに注意しよう。これはまさに高速ゼータ変換を使える形になっている。下のコードのような感じで in-place にできる。ここで、初期状態の  e h を表しているものとする。これによって、 e(S) の計算は  O(N 2^{N}) となる。

// 初期状態の e は h を表すものとする
// e に対して in-place に累積和をとるような更新をする
for (int i = 0; i < N; ++i) {
    for (int S = 0; S < (1<<N); ++S) {
        if (S & (1 << i)) {
             e[S] += e[S ^ (1 << i)];
        }
    }
}

 f(S) を求める

それでは、次の  f(S) を求めていきます。


 f(S) = 頂点集合  S ( \subset V) について、その内部にある辺たちの部分集合のうち、 S 全体が連結であるようなものの個数


このように、グラフの頂点集合の部分集合に対する何かを計算するときには、1 つ頂点  v を選んでその頂点  v に対する挙動で場合分するという方針がしばしば有効なイメージがある (例:彩色数)。

ここでも  v \in S を一つとる。また、 f(S) を求めるためには、逆に条件を満たさないものを数え上げて引くことにする。条件を満たさないとき、 v を含む連結成分は  S の真部分集合のいずれかになる。それを  T としたときは

 f(T) \times 2^{e(S - T)} 通り

となる。よって

 \displaystyle f(S) = 2^{e(S)} - \sum_{T \subsetneq S, v \in T} f(T) \times 2^{e(S - T)}

と求められることがわかった (除原理)。これは、 f(S) O(3^{N}) な bitDP で計算できることを意味している!!! あとは丁寧に実装していけば OK。

なお、Subset Convolution という超高度な技を使えば  O(N^{2} 2^{N}) で計算できるみたい (あまりよく分かってない...)。

コード

計算量は

  •  e(S) を求める: O(N 2^{N}) (高速ゼータ変換)
  •  f(S) を求める: O(3^{N}) (Subset Convolution を用いれば  O(N^{2} 2^{N})

となる。

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;





int main() {
    // 入力
    int N, M;
    cin >> N >> M;
    vector<vector<bool>> G(N, vector<bool>(N, false));
    for (int i = 0; i < M; ++i) {
        int a, b;
        cin >> a >> b;
        --a, --b;
        G[a][b] = G[b][a] = true;
    }

    // 2 の冪乗を確保しておく
    vector<mint> two(N*N + 1, 1);
    for (int i = 0; i < N*N; ++i) two[i+1] = two[i] * 2;

    // e(S) を求める
    vector<long long> e(1<<N, 0);
    for (int u = 0; u < N; ++u)  {
        for (int v = 0; v < N; ++v) {
            if (G[u][v]) {
                int S = (1<<u) | (1<<v);
                e[S] = 1;
            }
        }
    }
    for (int i = 0; i < N; ++i) {
        for (int S = 0; S < (1<<N); ++S) {
            if (S & (1<<i)) {
                e[S] += e[S ^ (1<<i)];
            }
        }
    }

    // f(S) を求める
    vector<mint> f(1<<N, 0);
    for (int S = 0; S < (1<<N); ++S) {
        if (S == 0) continue;
        
        // S に含まれる頂点 v を 1 つ選ぶ
        int v = -1;
        for (v = 0; v < N; ++v) if (S & (1<<v)) break;

        // 全体
        f[S] = two[e[S]];
        
        // S の部分集合を走査していく
        for (int T = S - 1; T >= 0; --T) {
            T &= S;

            // T が v を含む場合のみ
            if (T & (1 << v)) {
                f[S] -= f[T] * two[e[S ^ T]];
            }
        }
    }

    // 集計する
    int V = (1<<N) - 1;
    for (int k = 1; k < N; ++k) {
        mint res = 0;
        for (int S = 0; S < (1<<N); ++S) {
            if (!(S & (1<<0))) continue;
            if (!(S & (1<<k))) continue;
            res += f[S] * two[e[V ^ S]];
        }
        cout << res << endl;
    }   
}