けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Educational Codeforces Round 2 E. Lomsat gelral (R2300)

マージテク童貞を卒業した!!!

問題へのリンク

問題概要

 N 頂点の根付き木が与えられる (根の番号は 1)。また各頂点  v には色  c_{v} が塗られている。色は整数値で表される。各頂点  v について、以下の問いに答えよ。

  • その頂点を根とした部分木を考える
  • その部分木で「最頻出の色」をすべて求める
  • その最頻出の色の表す整数値の合計値を求めよ。

制約

  •  1 \le N \le 10^{5}

考えたこと

一見するとナイーブな木 DP で  O(N^{2}) かかるように思えてしまうけど、マージテクによって計算量が下がるという教育的問題!!!

各部分木について

  • どの色が何個あるかを表す counter (map で実装)
  • 最頻値がいくつか mean
  • 最頻値な色の総和 sum

を求めておく。木 DP でこれらの情報をマージするときに、

  • counter の要素数が小さい方から大きい方へとマージする

という風にするだけで、計算量が下がるのだ!!!これで全体を通して要素が統合される回数が  O(N\log{N}) となる。map を使っているので全体の計算量は  O(N(\log{N})^{2}) となる。

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

using Graph = vector<vector<int>>;
int N;
Graph G;
vector<int> col;
vector<long long> ans;

void rec(int v, int p, map<int, int> &counter, int &mean, long long &sum) {
    counter.clear();
    counter[col[v]]++;
    mean = 1;
    sum = col[v];
    for (auto e : G[v]) {
        if (e == p) continue;
        map<int, int> ch_counter;
        int ch_mean;
        long long ch_sum;
        rec(e, v, ch_counter, ch_mean, ch_sum);
        if (counter.size() < ch_counter.size()) {
            if (chmax(mean, ch_mean)) sum = ch_sum;
            else if (mean == ch_mean) sum += ch_sum;
            swap(counter, ch_counter);
            swap(mean, ch_mean);
            swap(sum, ch_sum);
        }
        for (auto it : ch_counter) {
            counter[it.first] += it.second;
            if (chmax(mean, counter[it.first])) sum = it.first;
            else if (mean == counter[it.first]) sum += it.first;
        }
    }
    ans[v] = sum;
}

int main() {
    scanf("%d", &N);
    col.resize(N);
    for (int i = 0; i < N; ++i) scanf("%d", &col[i]);
    ans.assign(N, 0);
    G.assign(N, vector<int>());
    for (int i = 0; i < N-1; ++i) {
        int u, v;
        scanf("%d %d", &u, &v);
        --u, --v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    map<int, int> counter;
    int mean;
    long long sum;
    rec(0, -1, counter, mean, sum);
    for (int i = 0; i < N; ++i) {
        if (i) printf(" ");
        printf("%lld", ans[i]);
    }
    printf("\n");
}