この問題をキッカケに準完全二分木のライブラリを拡充した!
問題概要
頂点数 の根付き木が与えられる。頂点番号は である。各頂点 (> 2) について、親頂点は である。
この根付き木において、頂点 からの距離が であるような頂点の個数を求めよ。
( ケース与えられる)
制約
考えたこと
次の総和を求めればよい。
- 頂点 からの深さが である頂点の個数
- 頂点 の親から見て、 を子孫に持たない側の子頂点からの深さが である頂点の個数
- 頂点 の親の親から見て、 を子孫に持たない側の子頂点からの深さが である頂点の個数
- ...
最終的に、根に到達する前に、頂点 からの距離が である頂点があれば、それもカウントする。
計算量は、 から根までの頂点数が で抑えられることから、全体の計算量は となる。
コード
この機会にライブラリとして整備した!
#include <bits/stdc++.h> using namespace std; // Find out of Strongly Balanced Binary Tree (N <= 10^18) // the vertex number is 1-indexed (root = 1) template<class mint> struct FindOutBinaryTree { // input data long long N; // main results vector<mint> depth_table; // depth_table[d] := # of nodes whose distance from root is d vector<mint> distance_table; // distance_tabls[l] := # of paths whose length is l // results of perfect binary trees vector<vector<mint>> perfect_depth_table, perfect_distance_table; // constructor FindOutBinaryTree() {} FindOutBinaryTree(long long n, bool build_dt = true) : N(n) { if (build_dt) init(n); } void set(long long n) { N = n; } void init(long long n) { N = n; int D = 0; while (n) { ++D, n /= 2; } findout_perfect_binary_tree(D); findout_binary_tree(); } // preprocess of perfect binary trees void findout_perfect_binary_tree(int D) { auto pre = [&](auto self, long long d) -> vector<mint> { if (d == 0) { perfect_depth_table[d] = vector<mint>({mint(1)}); return perfect_distance_table[d] = vector<mint>({mint(0), mint(1)}); } vector<mint> depth(d+1, 0), distance(d*2+2, 0); for (int i = 0; i <= d; ++i) depth[i] = mint(1LL<<i); for (int i = 0; i <= d; ++i) distance[i+1] += mint(1LL<<i); for (int i = 1; i <= d; ++i) for (int j = 1; j <= d; ++j) { distance[i+j+1] += mint(1LL<<(i-1)) * mint(1LL<<(j-1)); } const auto &left = self(self, d-1); for (int i = 0; i < left.size(); ++i) distance[i] += left[i] * 2; perfect_depth_table[d] = depth; return perfect_distance_table[d] = distance; }; perfect_depth_table.resize(D+1); perfect_distance_table.resize(D+1); pre(pre, D); } // get left depth and right depth pair<long long, long long> get_depth(long long v) { long long left_depth = 0, right_depth = 0; long long left = v, right = v; while (left * 2 <= N) ++left_depth, left = left * 2; while (right * 2 + 1 <= N) ++right_depth, right = right * 2 + 1; return {left_depth, right_depth}; } // find out the binary tree (size N) void findout_binary_tree() { auto rec = [&](auto self, long long v) -> pair<vector<mint>, vector<mint>> { vector<mint> depth, distance; if (v > N) return {depth, distance}; // examine the depth of left subtree and right subtree auto [ld, rd] = get_depth(v); if (ld == rd) return {perfect_depth_table[ld], perfect_distance_table[rd]}; // search the left subtree and right subtree auto [left_depth, left_distance] = self(self, v * 2); auto [right_depth, right_distance] = self(self, v * 2 + 1); depth.assign(max((int)left_depth.size(), (int)right_depth.size()) + 1, 0); distance.assign((int)left_depth.size() + (int)right_depth.size() + 2, 0); // update depth[0] = distance[1] = 1; for (int d = 0; d < (int)left_depth.size(); ++d) { depth[d + 1] += left_depth[d]; distance[d + 2] += left_depth[d]; } for (int d = 0; d < (int)right_depth.size(); ++d) { depth[d + 1] += right_depth[d]; distance[d + 2] += right_depth[d]; } for (int d1 = 0; d1 < (int)left_depth.size(); ++d1) { for (int d2 = 0; d2 < (int)right_depth.size(); ++d2) { distance[d1 + d2 + 3] += left_depth[d1] * right_depth[d2]; } } for (int l = 1; l < (int)left_distance.size(); ++l) { distance[l] += left_distance[l]; } for (int l = 1; l < (int)right_distance.size(); ++l) { distance[l] += right_distance[l]; } return {depth, distance}; }; auto [depth, distance] = rec(rec, 1); depth_table = depth; distance_table = distance; } // the number of nodes whose depth from v is d (v is 1-indexed) mint get_num_of_the_depth(long long v, long long d) { if (v <= 0 || v > N || d < 0) return mint(0); auto [left_depth, right_depth] = get_depth(v); if (left_depth < d) return mint(0); else if (right_depth >= d) return mint(1LL << d); else return mint(N - (v << d) + 1); } // the number of nodes whose distance from v is d (v is 1-indexed) mint get_num_of_the_distance(long long v, long long d) { if (v <= 0 || v > N) return mint(0); mint res = get_num_of_the_depth(v, d); for (long long i = 1; i <= d; ++i) { if (v == 1) break; if (i == d) { res += 1; break; } long long v2 = v / 2; if (v == v2 * 2 + 1) res += get_num_of_the_depth(v2 * 2, d - i - 1); else res += get_num_of_the_depth(v2 * 2 + 1, d - i - 1); v = v2; } return res; } }; void ABC_321_E() { auto solve = [&]() -> void { long long N, X, D; cin >> N >> X >> D; FindOutBinaryTree<long long> fbt(N, false); cout << fbt.get_num_of_the_distance(X, D) << endl; }; int T; cin >> T; while (T--) solve(); } int main() { ABC_321_E(); }