けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Codeforces Round #586 E. Tourism (R2200)

結構好き...だけど、完全既出だったらしい

問題へのリンク

問題概要

 N 頂点  M 辺の連結な単純無向グラフが与えられる。各頂点  v には重み  w_v が付いている。頂点  s を始点としたウォークであって、ウォーク上のどの辺  (u, v) に対してもその直後が  (v, u) ではない (直前に通った辺をそのまま引き返す移動をしない) ようなもののうち、ウォーク上の頂点の重みの総和の最大値を求めよ。

ただし、一度通った頂点を二度以上通ってもよいが、重みが加算されるのは一度のみである。また一度通った辺を二度以上通ってもよいが、前述の通り、直前に通った辺をそのまま引き返す移動は禁じられる。

制約

  •  2 \le N \le 10^{5}
  •  1 \le M \le 10^{5}

考えたこと

面白そうだと思った!!!!!

まずは具体的なグラフをもとに考察することにした。全部の頂点を回収できないのはどんな場合かを見極めるために。

まず一直線につながったパスグラフでは、 s が内側にあると、片方の方向にしか行けないことがわかった。一度左に移動したら、もう二度と右方向へ行くことはできない。

似たことはツリーでもいえる。ツリー上の  s を根としたとき、根から特定の葉までのパスのうち、頂点の重みの総和が最大なものが答えとなる。

ではサイクルを含むとどうなるか??
サイクルを含むと事情は変わり、 s からサイクルへと足を伸ばし、ぐるっとまわったうえで  s の方に戻ってくることができる。このことをさらに深く追求すると

  • グラフの「ヒゲ」を切り落とした範囲内はほぼ自由に行き来できる
  • グラフの「ヒゲ」を切り落とした範囲内を一点  S に潰すとツリーになる
  • そのツリー上で、 S を根として各葉までのパスのうち、重みが最大のところまで移動して終了するのが最適

ということがわかる。ヒゲを切り落とすのはまさに queue を用いた後退解析のような実装がピッタリである。queue を用いてヒゲを切り落とすサイクル検出方法については以下の記事に書いた:

qiita.com

#include <iostream>
#include <vector>
#include <queue>
using namespace std;

int N, M, S;
vector<long long> w;
using Graph = vector<vector<int> >;
Graph G;

void dfs(vector<int> vs, vector<long long> &dist) {
    for (auto v : vs) {
        for (auto nv : G[v]) {
            if (dist[nv] != -1) continue;
            dist[nv] = dist[v] + w[nv];
            dfs({nv}, dist);
        }
    }
}

long long solve() {
    vector<int> deg(N, 0);
    for (int v = 0; v < N; ++v) {
        for (auto nv : G[v]) {
            ++deg[v];
            ++deg[nv];
        }
    }
    for (int v = 0; v < N; ++v) deg[v] /= 2;

    queue<int> que;
    for (int v = 0; v < N; ++v) if (v != S && deg[v] == 1) que.push(v);

    vector<bool> fucked(N, false);
    while (!que.empty()) {
        int v = que.front(); que.pop();
        fucked[v] = true;
        for (auto nv : G[v]) {
            --deg[nv];
            if (nv != S && deg[nv] == 1) que.push(nv);
        }
    }

    long long res = 0;
    vector<int> vs;
    vector<long long> dist(N, -1);
    for (int v = 0; v < N; ++v) if (!fucked[v]) {
            res += w[v];
            vs.push_back(v);
            dist[v] = 0;
        }

    dfs(vs, dist);
    long long ma = 0;
    for (int v = 0; v < N; ++v) ma = max(ma, dist[v]);
    return res + ma;
}


int main() {
    cin >> N >> M;
    w.resize(N);
    for (int i = 0; i < N; ++i) cin >> w[i];
    G.assign(N, vector<int>());
    for (int i = 0; i < M; ++i) {
        int u, v; cin >> u >> v; --u, --v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    cin >> S; --S;

    cout << solve() << endl;
}