けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Educational Codeforces Round 83 G. Autocompletion (R2600)

頑張って DFS だけで通した!!!

問題へのリンク

問題概要

頂点数  N+1 の Trie 木と、そのうちの  K 個の頂点集合  S が与えられる。 S の各頂点  v について、トライ木の根から出発して、以下の操作によって到達するまでの最小コストを求めよ。

  • トライ木の辺を 1 本先に進める (親方向には進めない): コスト 1
  •  S のうち今いる頂点の子孫 (今いる頂点を含む) のみを取り出した集合を  S' として、 v S' に含まれるならば一気に移動する: コストは、 S' における  v の辞書順順位 (1-indexed)

制約

  •  1 \le N \le 10^{6}

考えたこと

とりあえず DP することを考える。dp[ v ] を v にいたるまでの最小コストとする。dp[ v ] を求めるにあたり、以下の 2 つの場合に分けられる

  • 1 個上の親 p から 1 歩進んで到達する
  • いずれかの先祖 w から、一気に到達する

このうちの前者については簡単で、chmin(dp[ v ], dp[ p ] + 1)。
後者についてちゃんと考えることにする。

DFS

ここで、以下の変数を導入しよう。

  • iter[ v ] := トライ木を辞書順に DFS していったとき、頂点 v に到達した時点で、S に含まれる頂点を何個訪れたか? (ただし v 自身が S に含まれるときは、この個数に v を含めないこととする)

このとき、w から v へと移動した場合の DP 遷移は次のように表せるのだ。

  • chmin(dp[ v ], dp[ w ] + (iter[ v ] - iter[ w ] + 1))

ここで、iter[ v ] - iter[ w ] + 1 というのが、移動コストを表す。この遷移式をよく眺めて、集める DP にするとこうなる!

  • dp[ v ] = iter[ v ] + min(w: v の先祖) (dp[ w ] - iter[ w ] + 1)

こうしてみると、min の部分は累積和をとるような要領で管理すれば、DFS 一発で全頂点に対する DP 値が求められることがわかった!!!

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

using Edge = pair<int, char>;
using Graph = vector<vector<Edge>>;
const int INF = 1<<29;

void rec(const Graph &G, const vector<int> &isS, vector<int> &dp,
         int v, int &iter, int dp_of_parent, int min_dp_iter) {
    chmin(dp[v], dp_of_parent + 1);
    if (isS[v]) chmin(dp[v], iter + min_dp_iter);
    chmin(min_dp_iter, dp[v] - iter + 1);
    if (isS[v]) ++iter;
    for (auto e : G[v])
        rec(G, isS, dp, e.first, iter, dp[v], min_dp_iter);
}

int main() {
    // input
    int N; scanf("%d", &N);
    Graph G(N+1);
    for (int i = 0; i < N; ++i) {
        int v; char c;
        scanf("%d %c", &v, &c);
        G[v].emplace_back(i+1, c);
    }
    for (int i = 0; i <= N; ++i)
        sort(G[i].begin(), G[i].end(),
             [&](Edge x, Edge y) { return x.second < y.second; });
    int K; scanf("%d", &K);
    vector<int> S(K), isS(N+1, false);
    for (int i = 0; i < K; ++i) scanf("%d", &S[i]), isS[S[i]] = true;

    // solve
    int iter = 1;
    vector<int> dp(N+1, INF);
    rec(G, isS, dp, 0, iter, -1, 0);
    vector<int> res(K, -1);
    for (int i = 0; i < K; ++i) { 
        if (i) printf(" ");
        printf("%d", dp[S[i]]);
    }
    printf("\n");
}