こういうの素早く解けるようになりたいね。
いわゆる「トポロジカルソート順の数え上げ」という高難易度でたまに見るパターンの問題。
問題概要
頂点 辺の無向グラフが与えられる。ここで、辺 が全域木を形成していることが保証されている。
各辺に の重みをつける 通りの方法のうち、 の辺たちが最小全域木をなすものが何通りあるか、109 + 7 で割った余りを求めよ。
制約
考えたこと: トポロジカルソート数え上げへの帰着
まずは最小全域木の特徴づけが
- MST をなす辺以外の各辺 に対して、その辺と MST とでサイクルが 1 つ形成されるが、その中で の重みが最大である
で与えられることを思い出す (ここ参照)。そうすると、この問題は に対して重み を割り当てる方法のうち、
- , となる について、 を満たす必要がある
というタイプの制約がひたすら並ぶようなトポロジカルソート順を走査する問題になる。なんか、こういう風に「制約を書き出すとトポロジカルソート走査問題に帰着するタイプの問題」は、下のような問題でも見た。
トポロジカルソート走査・数え上げの指針
一般には #P-complete な問題だが、今回数え上げたい有向グラフは
- 二部グラフであって、左ノードから右ノードへの順序がある形状
- 左ノード数が 個
という特徴がある。そして左ノードの順序を固定してみると、例えば i < j であって、i -> a、j -> a という順序がある部分について、i -> a は取り除いてよくて、下図 (editorial 引用) のような形の有向グラフのトポロジカルソートを走査することになる。
これは、比較的容易にできる。ポイントとしては右から順に並べる。この順序を間違うと、よくわからなくなる。 にくっついているノードの個数をそれぞれ とする (それぞれ 1, 2, ..., N-1 も含むようにする) と、 を順列として
- まず、 と 個とを並べる 通り
- 次に、 を の左において、 個を挿入することを考えると、 通り
- 次に、 を の左において、 個を挿入することを考えると、 通り
- ...
これを掛け算することになる。さらに本当に求めるものは「個数」ではなく「重みの総和」であるが、これは頑張る。まず
白いボール 個と、黒いボール 個あるときに、その並びが 個あって、黒いボールの index の総和が であったとする。ここから、各箇所にボールを挿入したときの の変化は、
- 挿入するのが黒のとき、 は のままで、 は になる
- 挿入するのが白のとき、 は になって、 は になる
ということに注目する。 の変化がとてもわかりづらいけど、一度知ってしまえば良さそうな感じ。理由としては
- 各黒ボールが右にずれる分を除外したときに、
- index が となっている各黒ボールにつき、 が右にずれるような の左側の挿入箇所は 通りあって、それらそれぞれについて が加算されるイメージ、したがって全体としては が加算される
ということで、合計で になる。
というわけでまとめると、
すでにある黒ボール 個と白ボール 個の並びが、 通りあって、黒ボールの index の合計が であるときに、黒ボール 1 個と、白ボール 個とを加えたとき、
- は に
- は に
- は に
- は に
なる
bitDP へ
ここまでは左ノード 1, 2, ..., N-1 の順序を固定したときの話だが、順序をいじるなら bitDP でできそう。
- cdp[ bit ] := bit で表されたノードと、そこから飛び出ているノードを並べる場合の数
- sdp[ bit ] := bit で表されたノードと、そこから飛び出ているノードたちの並びすべてについての bit で表されたノードの index 和の総和
としておく。このとき、更新は、
- con[ bit ] := bit に含まれるのが何個か
- 1, 2, ..., N-1 のうち、bit で表される頂点たちから繋がっている頂点数を cnt[ bit ] (bit の個数も含む)
としておいて、
- bit に含まれない要素 i を足して nbit = bit | (1<<i) として
- cdp[nbit] = (cnt[nbit])(cnt[nbit ] - 1)...(cnt[bit] + 2) × cdp[bit]
- sdp[nbit] = (cnt[nbit] + 1)(cnt[nbit])...(cnt[bit] + 3) × sdp[bit] + (con[bit] + 1)(cnt[nbit])(cnt[nbit ] - 1)...(cnt[bit] + 2) × cdp[bit]
で求められる。このあたりの計算量は でできる。
cnt[ bit ] の求め方
残るタスクは、cnt[ bit ] を求めること。すなわち bit のいずれかの要素から出ている頂点集合のサイズを求めること。ここを愚直にやっては かかっておそらく間に合わない。
これは、
- f( S ) := 「辺 i = N, N+1, ..., M に対して、MST に対して辺 i を付け加えてできるサイクルに含まれる辺 i 以外の辺 (1 以上 N-1 以下) の集合が S である」ような i の個数
としてあげて、
- g( S ) = f( T )
- cnt[ S ] = M - (N-1) - g( S の補集合 ) + |S|
となることがわかるので、これは高速ゼータ変換によって求めることができる。
コード
以上をまとめる。計算量は、
- 高速ゼータ変換:
- bitDP:
#include <iostream> #include <vector> #include <bitset> using namespace std; template<int MOD> struct Fp { long long val; constexpr Fp(long long v = 0) noexcept : val(v % MOD) { if (val < 0) v += MOD; } constexpr int getmod() { return MOD; } constexpr Fp operator - () const noexcept { return val ? MOD - val : 0; } constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; } constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; } constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; } constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; } constexpr Fp& operator += (const Fp& r) noexcept { val += r.val; if (val >= MOD) val -= MOD; return *this; } constexpr Fp& operator -= (const Fp& r) noexcept { val -= r.val; if (val < 0) val += MOD; return *this; } constexpr Fp& operator *= (const Fp& r) noexcept { val = val * r.val % MOD; return *this; } constexpr Fp& operator /= (const Fp& r) noexcept { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b; swap(a, b); u -= t * v; swap(u, v); } val = val * u % MOD; if (val < 0) val += MOD; return *this; } constexpr bool operator == (const Fp& r) const noexcept { return this->val == r.val; } constexpr bool operator != (const Fp& r) const noexcept { return this->val != r.val; } friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept { return os << x.val; } friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept { return is >> x.val; } friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept { if (n == 0) return 1; auto t = modpow(a, n / 2); t = t * t; if (n & 1) t = t * a; return t; } }; // 二項係数ライブラリ template<class T> struct BiCoef { vector<T> fact_, inv_, finv_; constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) { int MOD = fact_[0].getmod(); for(int i = 2; i < n; i++){ fact_[i] = fact_[i-1] * i; inv_[i] = -inv_[MOD%i] * (MOD/i); finv_[i] = finv_[i-1] * inv_[i]; } } constexpr T com(int n, int k) const noexcept { if (n < k || n < 0 || k < 0) return 0; return fact_[n] * finv_[k] * finv_[n-k]; } constexpr T fact(int n) const noexcept { if (n < 0) return 0; return fact_[n]; } constexpr T inv(int n) const noexcept { if (n < 0) return 0; return inv_[n]; } constexpr T finv(int n) const noexcept { if (n < 0) return 0; return finv_[n]; } }; const int MAX = 201010; const int MOD = 1000000007; using mint = Fp<MOD>; int N, M; using pint = pair<int,int>; using Graph = vector<vector<pint> >; Graph G; int dfs(int v, int p, int goal) { if (v == goal) return 0; int res = -1; for (auto e : G[v]) { if (e.first == p) continue; int tmp = dfs(e.first, v, goal); if (tmp != -1) { if (res == -1) res = tmp | (1<<e.second); else res |= tmp | (1<<e.second); } } return res; } int main() { BiCoef<mint> bc(MAX); // 入力と、MST 以外の辺が MST と作るサイクルたち cin >> N >> M; G.assign(N, vector<pint>()); for (int i = 0; i < N-1; ++i) { int a, b; cin >> a >> b; --a, --b; G[a].push_back({b, i}); G[b].push_back({a, i}); } vector<int> cmp(1<<(N-1), 0); for (int i = N-1; i < M; ++i) { int a, b; cin >> a >> b; --a, --b; int S = dfs(a, -1, b); cmp[S]++; } // 高速ゼータ変換 for (int i = 0; i < N-1; ++i) for (int bit = 0; bit < 1<<(N-1); ++bit) if (bit & (1<<i)) cmp[bit] += cmp[bit ^ (1<<i)]; vector<int> cnt(1<<(N-1), 0); for (int bit = 0; bit < 1<<(N-1); ++bit) { cnt[bit] = M-(N-1) - cmp[(1<<(N-1))-1 - bit] + __builtin_popcount(bit); } // DP vector<mint> cdp(1<<(N-1), 0), sdp(1<<(N-1), 0); cdp[0] = 1; for (int bit = 0; bit < 1<<(N-1); ++bit) { long long con = __builtin_popcount(bit); for (int i = 0; i < N-1; ++i) { if (bit & (1<<i)) continue; int nbit = bit | (1<<i); cdp[nbit] += bc.fact(cnt[nbit]-1) * bc.finv(cnt[bit]) * cdp[bit]; sdp[nbit] += bc.fact(cnt[nbit]) * bc.finv(cnt[bit]+1) * sdp[bit] + bc.fact(cnt[nbit]-1) * bc.finv(cnt[bit]) * cdp[bit] * (con + 1); } } cout << sdp[(1<<(N-1)) - 1] << endl; }