けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 259 F - Select Edges (青色, 500 点)

木 DP のいい感じの練習問題だった!

問題概要

頂点数  N の重み付き木が与えられる (重みは負のこともある)。

この木の辺の部分集合であって、各頂点  i に接続する辺の本数が  d_{i} 本以下であるようなものに対して、辺の重みの総和の最大値を求めよ。

制約

  •  2 \le N \le 3 \times 10^{5}
  •  0 \le d_{i} \le (頂点  i の次数)
  •  -10^{9} \le (各辺の重み)  \le 10^{9}

考えたこと

最初は Kruskal 法のようなノリで Greedy にできないかと考えたけど無理そうだった。

木 DP で解けそうだと思った。頂点 0 を根として根付き木を作る。このとき、

  • dpfull[v] ← 頂点  v を根とした根付き木において、頂点  v に接続する辺が  d_{v} 本以下となるように辺を選んだときの、辺の重みの総和の最大値
  • dprem[v] ← 頂点  v を根とした根付き木において、頂点  v に接続する辺が  d_{v} 本未満となるように辺を選んだときの、辺の重みの総和の最大値

として木 DP した。

dpfull[v] や、dprem[v] を求める際には、頂点  v の子頂点のうち、どの子頂点へと繋がる辺を優先的に選べば良いかを Greedy に求めた。

その部分でソート処理を行う必要があるため、全体の計算量は  O(N \log N) となる。

コード

#include <bits/stdc++.h>
using namespace std;
using Edge = pair<int, long long>;
using Graph = vector<vector<Edge>>;
const long long INF = 1LL<<60;

int main() {
    int N;
    cin >> N;
    vector<int> d(N);
    for (int i = 0; i < N; ++i) cin >> d[i];
    Graph G(N);
    for (int i = 0; i < N-1; ++i) {
        long long u, v, w;
        cin >> u >> v >> w;
        --u, --v;
        G[u].emplace_back(v, w);
        G[v].emplace_back(u, w);
    }
    
    // 木 DP
    vector<long long> dpfull(N, -INF), dprem(N, -INF);
    auto rec = [&](auto self, const Graph &G, int v, int p) -> void {
        long long base = 0;
        vector<long long> diff;  // その子頂点への辺を選ぶと、選ばない場合と比べてどこまで伸びるか
        for (auto e : G[v]) {
            if (e.first == p) continue;
            self(self, G, e.first, v);
            base += dpfull[e.first];
            diff.push_back(e.second + dprem[e.first] - dpfull[e.first]);
        }
        
        // その子頂点への辺を選ぶことでスコアが伸びるところから Greedy に取っていく
        long long fulladd = 0, remadd = 0;
        sort(diff.begin(), diff.end(), greater<long long>());
        for (int i = 0; i < min(d[v], (int)diff.size()) && diff[i] >= 0; ++i) {
            fulladd += diff[i];
            if (i < d[v] - 1) remadd += diff[i];
        }
        dpfull[v] = base + fulladd;
        if (d[v]) dprem[v] = base + remadd;
    };
    rec(rec, G, 0, -1);
    
    cout << dpfull[0] << endl;
}