好き系。本番解きに行ったけど、もっとサクッと通せればよかった。
問題概要
頂点の無向グラフ (頂点の番号は 1 〜 ) であって、以下の条件をすべて満たすものの個数を 1000000007 で割ったあまりを求めよ。
- 頂点 1 と 2 との間の最短距離が
制約
- <
考えたこと
こういう入力が少ないタイプの数え上げは大好き!!!
グラフとか盤面とか、割とでかいものを数え上げるのは高難易度でありがちな感じ。
まずは数え上げの方針を立てるところから。最短距離が となるパスをどうするか、という風に考えるとちょっと大変そう。最短経路が複数あるようなグラフとか重複なく数えるのは難しそう。グラフをどうまとめ上げるかをきちんと考えないと。
で、今回は「グラフの最短路木」で分類してあげるとよさそう。
こんな感じで各深さにおいてノードが何個あるか、みたいな感じでグラフを分類してあげて、それぞれを数え上げる感じ。状態のまとめ方としては
- dp[ n ][ d ][ a ] := N 個のうち n 個のノードを選んで (「2」は使わずにとっておく)、深さ d のグラフを作り、最終層のノード数が a となるグラフの作り方
という感じがよさそう。これを作っておくと、dp[ n ][ D-1 ][ a ] に対して、「2」と「その残り」をくっつけていく感じでできる:
- 各 a に対して、dp[ n ][ D-1 ][ a ] × × × を求めて足す
- は、ノード「2」を深さ のどこかにくっつける場合の数 (どこかとは繋がらないといけないので -1 している)
- は、ノード「2」以外と深さ の層についてどうするか (どことも繋がらないことが許される)
- は、「2」も含めた残りのノード同士のつながり
dp テーブルの方は
- 深さ d の層のノードたちはどれも深さ d-1 の層のノードとどこかとつながる (それより前の層とはつながってはいけない)
- 深さ d のノード同士はつながってもつながっていなくてもいい
というのを考慮して組むとできた。
- dp[n][d][a] += dp[n-a][d-1][b] × ×
という感じ。全体的の計算量は dp テーブルを作るところがボトルネックで、3 次元 + 遷移が 1 次元で、 になる。
#include <iostream> #include <cstring> using namespace std; const int MAX = 210000; const int MOD = 1000000007; long long fac[MAX],finv[MAX],inv[MAX]; void COMinit(){ fac[0] = fac[1] = 1; finv[0] = finv[1] = 1; inv[1] = 1; for(int i = 2; i < MAX; i++){ fac[i] = fac[i-1] * i % MOD; inv[i] = MOD - inv[MOD%i] * (MOD/i) % MOD; finv[i] = finv[i-1] * inv[i] % MOD; } } long long COM(int n,int k){ if(n < k) return 0; if (n < 0 || k < 0) return 0; return fac[n] * (finv[k] * finv[n-k] % MOD) % MOD; } long long modpow(long long a, long long n, long long mod) { long long res = 1; while (n > 0) { if (n & 1) res = res * a % mod; a = a * a % mod; n >>= 1; } return res; } long long N, D; long long dp[51][51][51]; long long rec(long long n, long long d, long long a) { if (d == 0) { if (n == 1 && a == 1) return 1; else return 0; } if (n < a + d) return 0; if (dp[n][d][a] != -1) return dp[n][d][a]; long long res = 0; for (long long b = 1; b <= n; ++b) { long long tmp = rec(n-a, d-1, b); long long ch = COM(N-n+a-1, a); long long bpow = (1LL<<b) - 1; bpow %= MOD; long long b2 = modpow(bpow, a, MOD); long long rem = modpow(2LL, a*(a-1)/2, MOD); tmp = tmp * ch % MOD * b2 % MOD * rem % MOD; (res += tmp) %= MOD; } return dp[n][d][a] = res; } int main() { COMinit(); cin >> N >> D; memset(dp, -1, sizeof(dp)); long long res = 0; for (long long n = 1; n <= N; ++n) { for (long long a = 1; a <= n; ++a) { long long tmp = rec(n, D-1, a); long long apow = (1LL<<a) - 1; apow %= MOD; long long a2 = apow * modpow((1LL<<a)%MOD, N-n-1, MOD) % MOD; long long rem = modpow(2LL, (N-n)*(N-n-1)/2, MOD); tmp = tmp * a2 % MOD * rem % MOD; (res += tmp) %= MOD; } } cout << res << endl; }