けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

競プロ典型 90 問 003 - Longest Circular Road(★4)

木の直径を求めよ、という問題。直径を考えると解ける問題は高難易度でもお馴染みですね。

類題とか

drken1215.hatenablog.com

問題概要

頂点数  N の木が与えられます。

この木に新たに辺を 1 本追加すると、閉路が 1 つできます。

このようにして形成される閉路に含まれる辺の本数として考えられる最大値を求めてください。

制約

  •  3 \le N \le 10^{5}

考えたこと

まずは木に辺を付け加えることで「閉路」が形成されるという感覚を掴んでみましょう。

まず木とは「閉路を持たない (連結な) グラフ」のことです。下図のように、木に辺を 1 本加えると、閉路が 1 つ形成されます。

f:id:drken1215:20210612174958p:plain

閉路の長さとは

そして、頂点  u, v に辺を結ぶことでできる閉路の長さは、


パス  u- v の長さ + 1


になります。上図でいえば、頂点 3 と頂点 6 を終点に持つパス 3-6 の長さが 4 ですので、辺 (3, 6) を追加することでできる閉路の長さは 5 になります。

よって問題は

「木の 2 頂点  u, v を選んだときの、パス  u- v の長さの最大値を求めよ」

と言い換えられます。ちなみに、このように「グラフの 2 頂点間の距離の最大値」のことを、そのグラフの直径と言います。つまり今回の問題は「木の直径の長さを求めてください」ということになります。

木の直径

木の直径は次の方法で求められます (証明は略)。

  • 適当な頂点  u を 1 つ選ぶ
  • 頂点  u から最も遠い頂点  v を求める ( O(N))
  • 頂点  v から最も遠い頂点  w を求める ( O(N))

このとき、パス  v- w の長さが求める直径になります。それに 1 を足して出力すればよいです。

なお、木において 1 頂点  s から各頂点  v への距離を求めるのは DFS でも BFS でも求められます。下に挙げるコードでは DFS で実装しています。

計算量は  O(N) になります。

コード (C++)

ここでは stack を用いた非再帰の DFS で解きます。BFS による解法は E8 さんのコードにあります。

#include <bits/stdc++.h>
using namespace std;

// グラフを表すデータ型
using Graph = vector<vector<int>>;

// 頂点 s から DFS (ここではスタックを使う)
vector<int> dfs(const Graph &G, int s) {
    int N = G.size();
    
    // 頂点 s からの距離
    vector<int> dist(N, -1);
    dist[s] = 0;

    // スタックで DFS
    stack<int> st({s});
    while (!st.empty()) {
        int v = st.top();
        st.pop();
        for (auto nv: G[v]) {
            if (dist[nv] == -1) {
                st.push(nv);
                dist[nv] = dist[v] + 1;
            }
        }
    }

    // リターン
    return dist;
}

int main() {
    // 入力
    int N;
    cin >> N;
    Graph G(N);
    for (int i = 0; i < N-1; ++i) {
        int a, b; cin >> a >> b; --a, --b;
        G[a].push_back(b);
        G[b].push_back(a);
    }

    // 頂点 0 から DFS
    auto dist0 = dfs(G, 0);

    // 距離最大の点を求める
    int mx = -1, mv = -1;
    for (int v = 0; v < N; ++v) {
        if (mx < dist0[v]) {
            mx = dist0[v];
            mv = v;
        }
    }

    // 頂点 mv から DFS
    auto distmv = dfs(G, mv);

    // その最大値を求める
    mx = -1;
    for (int v = 0; v < N; ++v) {
        mx = max(mx, distmv[v]);
    }
    cout << mx + 1 << endl;
}

コード (Python3)

# 入力
N = int(input())
G = [[] for _ in range(N)]
for _ in range(N - 1):
    a, b = map(int, input().split())
    a, b = a - 1, b - 1  # 0-indexed に
    G[a].append(b)
    G[b].append(a)

# 頂点 s から DFS (ここではスタックを使う)
def dfs(s):
    # 頂点 s からの距離
    dist = [-1] * N
    dist[s] = 0

    # スタックで DFS
    st = [s]
    while st:
        v = st.pop()
        for nv in G[v]:
            if dist[nv] == -1:
                st.append(nv)
                dist[nv] = dist[v] + 1

    # リターン
    return dist

# 頂点 0 から
dist0 = dfs(0)
mv = max(enumerate(dist0), key=lambda x: x[1])[0]
distmv = dfs(mv)
print(max(distmv) + 1)