けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 152 F - Tree and Constraints (青色, 600 点)

包除原理をとても素朴な状態で問いかける問題!!!

問題へのリンク

問題概要

 N 頂点のツリーが与えられる。ツリーの各辺を白色または黒色に塗る  2^{N-1} 通りの方法のうち、以下の  M 個の制約をすべて満たすものの個数を求めよ。

  •  i 個目の制約は 2 頂点  a, b が指定され、 a b とを結ぶパスには黒辺が含まれるようにする

制約

  •  2 \le N \le 50
  •  1 \le M \le 20

考えたこと

まず一目思うこととして

  • パス a-b 中のうち一つ以上を黒辺にせよ

という制約はとても扱いづらい!!!!!しかも「扱いづらい制約が  M 個あってすべて満たす必要がある」という状況。こういうのは包除原理でうまくいくことがよくある。こんな風にする:

(制約をすべて満たすもの)
= (全部の塗り方)
- (制約のうち 1 個以上を満たさないもの)
+ (制約のうち 2 個以上を満たさないもの)
- (制約のうち 3 個以上を満たさないもの)
+ ...

たとえば  M = 3 のときは、こうなる。

(制約 1, 2, 3 をすべて満たすもの)
= (全部の塗り方)
- 制約 1 を満たさないもの
- 制約 2 を満たさないもの
- 制約 3 を満たさないもの
+ 制約 1, 2 を満たさないもの
+ 制約 2, 3 を満たさないもの
+ 制約 3, 1 を満たさないもの
- 制約 1, 2, 3 を満たさないもの

満たさない制約番号の指定の仕方は  2^{M} 通りある。それぞれについて、以下のようにすればよい。

  • 満たさない制約番号の集合を  S = {i_{1}, \dots, i_{K}} としたとき ( K 個)
  • これらの制約を満たさない場合の数を  f(S) として
  •  {\rm res} += (-1)^{K}f(S) とする

よって、それぞれの  S についての  f(S) の計算に要する時間を  O(X) としたならば、計算量は  O(X2^{M}) となる。次に  f(S) の計算の仕方を考える。

DFS

 i_{1}, \dots, i_{K} 番目の制約を満たさないものを数え上げることを考える。これはつまり、

  •  K 本のパスが与えられるので
  • それらのパス上の辺はすべて白色となるようにする

ということになる。よって、 K 本のパスに含まれるような辺の本数を  p としたならば、それらは白色に塗った上で、残りの  N-1-p 本については自由に着色してよいので  2^{N-1-p} 通りとなる。よって問題は「パス中の辺を列挙する」という問題になった。

これは DFS でできる。ただしいつもの DFS と違って、パスの復元をする必要があるので少し工夫が必要になる。

前処理

実装上は  M 個のパスそれぞれに対して、パスに含まれる辺集合を列挙しておくとよさそう。さらに、辺の本数はたかだか 50 本なので、辺集合は long long 型の変数として管理できる。

全体の計算量は、包除原理の各ステップについて、  O(NM) だが、long long 型変数の OR 演算で表せることから実質  O(M) という感じになる。なのでちゃんとした計算量は  O(NM2^{M}) だが、実質  O(N2^{M}) という感じになる。

#include <iostream>
#include <vector>
using namespace std;
using Edge = pair<int, int>;
using Graph = vector<vector<Edge>>;

int N, M;
Graph G;
vector<int> us, vs;

bool rec(int v, int p, int target, long long &path) {
    if (v == target) {
        path = 0;
        return true;
    }
    for (auto e : G[v]) {
        if (e.first == p) continue;
        if (rec(e.first, v, target, path)) {
            path |= (1LL<<e.second);
            return true;
        }
    }
    return false;
}

long long solve() {
    vector<long long> paths(M, 0);
    for (int i = 0; i < M; ++i) rec(us[i], -1, vs[i], paths[i]);
    long long res = 0;
    for (long long bit = 0; bit < (1<<M); ++bit) {
        long long val = 0;
        for (int i = 0; i < M; ++i) if (bit & (1LL<<i)) val |= paths[i];
        long long remnum = N-1 - __builtin_popcountl(val);
        if (__builtin_popcountl(bit) % 2 == 0) res += 1LL<<remnum;
        else res -= 1LL<<remnum;
    }
    return res;
}

int main() {
    cin >> N;
    G.assign(N, vector<Edge>());
    for (int i = 0; i < N-1; ++i) {
        int x, y; cin >> x >> y; --x, --y;
        G[x].emplace_back(y, i);
        G[y].emplace_back(x, i);
    }
    cin >> M;
    us.resize(M), vs.resize(M);
    for (int i = 0; i < M; ++i) {
        cin >> us[i] >> vs[i];
        --us[i], --vs[i];
    }
    cout << solve() << endl;
}