けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AISing Programming Contest 2019 E - Attack to a Tree (600 点)

やっと会えたね、二乗の木DPたん...

問題へのリンク

問題概要

 N 頂点のツリーがあって、各頂点には値  A_i が割り振られている。今ツリーのエッジを何本か取り除いて何個かの連結成分に分けたとき

  • 連結成分内に含まれる全ノードの頂点の重みの和が負の値
  • 連結成分内に含まれる全ノードの頂点の重みが正の値

のいずれかを満たすようにしたい。そのような状態を実現するために取り除くエッジの本数の最小値を求めよ。

制約

  •  1 \le N \le 5000

考えたこと

とりあえず見るからにツリー DP。

  • dp[ v ][ num ][ 0 ] := v を根とする熱付き木において、num 本のエッジの切り方を考えたときに、v を含まない連結成分はすべて条件を満たすようにしたときの、v を含む連結成分内の重みの総和としてとりうる最小値

  • dp[ v ][ num ][ 1 ] := v を根とする熱付き木において、num 本のエッジの切り方を考えたときに、v を含まない連結成分はすべて条件を満たすようにして、かつ、v を含む連結成分内の頂点の重みはすべて正になるようにしたときの、v を含む連結成分内の重みの総和としてとりうる最小値 (実際はできるかどうかの判定だけでいい)

とすればできそう。が、このままだと  O(N^{3}) かかってしまうように見える。というのも、各頂点  v と各個数 num に足して、num 回のループを回すイメージなので。

が、実はなんと  O(N^{2}) になるというのが二乗の木DP

#include <iostream>
#include <vector>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }
template<class T1, class T2> ostream& operator << (ostream &s, pair<T1,T2> P)
{ return s << '<' << P.first << ", " << P.second << '>'; }
const long long INF = 1LL<<60;
const int MAX = 5100;

int N;
vector<long long> A;
vector<vector<int> > G;

int num[MAX]; // v の部分木に何個の頂点が含まれるか
long long dp[MAX][MAX][2]; // vの部分木、vを含めず何個か、(全部正かどうか)
long long sdp[MAX][2]; // 頂点 v の子ノード ch たちに足してナップサック

void rec(int v, int p) {
    num[v] = 1;
    for (auto ch : G[v]) {
        if (ch == p) continue;
        rec(ch, v);
        num[v] += num[ch];
    }
    
    for (int j = 0; j <= num[v]; ++j) sdp[j][0] = sdp[j][1] = INF;
    
    sdp[0][0] = A[v];
    if (A[v] > 0) sdp[0][1] = A[v];
    
    int curnum = 0; // 現在までに見た部分木のノード数の和 (これを管理しないと TLE する)
    for (auto ch : G[v]) {
        if (ch == p) continue;
        for (int j = curnum; j >= 0; --j) {
            long long tmp0 = sdp[j][0], tmp1 = sdp[j][1];
            sdp[j][0] = sdp[j][1] = INF;
            for (int k = 0; k <= num[ch]; ++k) {
                chmin(sdp[j+k][0], tmp0 + dp[ch][k][0]);
                if (dp[ch][k][0] < 0 || dp[ch][k][1] < INF/2) chmin(sdp[j+k+1][0], tmp0);
                if (A[v] > 0) {
                    chmin(sdp[j+k][1], tmp1 + dp[ch][k][1]);
                    if (dp[ch][k][0] < 0 || dp[ch][k][1] < INF/2) chmin(sdp[j+k+1][1], tmp1);
                }
            }
        }
        curnum += num[ch];
    }
    for (int j = 0; j <= num[v]; ++j) {
        dp[v][j][0] = sdp[j][0];
        dp[v][j][1] = sdp[j][1];
    }
}

int main() {
    while (cin >> N) {
        A.resize(N);
        for (int i = 0; i < N; ++i) cin >> A[i];
        G.assign(N, vector<int>());
        for (int i = 0; i < N-1; ++i) {
            int u, v; cin >> u >> v; --u, --v;
            G[u].push_back(v);
            G[v].push_back(u);
        }
        for (int i = 0; i < MAX; ++i) for (int j = 0; j < MAX; ++j)
            dp[i][j][0] = dp[i][j][1] = INF;
        rec(0, -1);
        int res = N;
        for (int i = 0; i <= N; ++i) {
            if (dp[0][i][0] < 0) chmin(res, i);
            if (dp[0][i][1] < INF/2) chmin(res, i);
        }
        cout << res << endl;
    }
}