けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 359 G - Sum of Tree Distance (3D, 黄色, 600 点)

いろんな解法あり。

問題概要

頂点数  N の無向木が与えられる。各頂点  v には色  A_{v} が塗られている。

このグラフにおいて、 A_{u} = A_{v} であるような頂点組  (u, v) ( u \lt v) についての、2 点  u, v 間の距離の総和を求めよ。

制約

  •  2 \le N \le 2 \times 10^{5}

考えたこと

マージテクで考えた。例によって頂点  0 を根とする木 DP を考える。


dp[v][c] ← 頂点  v を根とする部分木について、色が  c であるような各頂点についての深さ (基準は頂点 0) の総和と、色が  c であるような頂点の個数の組


ここで、dp[v] の部分を連想配列で実装することにより、木 DP の更新の部分にマージテクを適用できる。連想配列として map を用いると、木 DP 全体の計算量は  O(N (\log N)^{2}) となる。

求める総和

 A_{u} = A_{v} であるような頂点組  (u, v) ( u \lt v) についての、2 点  u, v 間の距離の総和を求めるためには、上記の DP の過程において、部分木と部分木とをマージするときに加算するとやりやすい。

頂点 0 (根) からの深さが  d であるような頂点  v の部分木であって、上記の DP 値が  (s_{1}, m_{1}),  (s_{2}, m_{2}) であるような部分木をマージする場面を考える。

このとき、頂点  v をまたぐようなパスについての、色が  c 同士の頂点対の距離の総和は次のように計算できる。

 s_{1}m_{2} + s_{2}m_{1} - 2d m_{1} m_{2}

この値を逐次足していけば良い。

コード

#include <bits/stdc++.h>
using namespace std;
using Graph = vector<vector<int>>;
using pll = pair<long long, long long>;

int main() {
    int N;
    cin >> N;
    Graph G(N);
    for (int i = 0; i < N - 1; ++i) {
        int u, v; cin >> u >> v; --u, --v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    vector<int> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i], --A[i];
    
    long long ans = 0;
    auto rec = [&](auto rec, int v, int p, int d) -> map<int, pll> {
        map<int, pll> res;
        res[A[v]] = pll(d, 1);  // (sum of depth, num)
        for (auto ch : G[v]) {
            if (ch == p) continue;
            auto sub = rec(rec, ch, v, d + 1);
            if (res.size() < sub.size()) swap(res, sub);
            for (auto [v, val] : sub) {
                ans += res[v].first * val.second + val.first * res[v].second
                - res[v].second * val.second * d * 2;
                res[v].first += val.first;
                res[v].second += val.second;
            }
        }
        return res;
    };
    rec(rec, 0, -1, 0);
    cout << ans << endl;
}