けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 301 E - Pac-Takahashi (青色, 475 点)

前処理して頂点数を減らしたグラフ上で TSP!!! ICPC ではすごくよく見るパターンですね!

問題概要

 H \times W サイズのグリッドがある。各マスは

  • 壁マス:'#'
  • 通路マス:'.'
  • お菓子マス:'o' (18 個以内であることが保証される)
  • スタートマス:'S'
  • ゴールマス:'G'

のいずれかである。

高橋君がスタートマスにいるので、壁以外のマスのみを通って、 T 回以内の移動でゴールマスに行きたい。

それが可能な場合、なるべく多くのお菓子マスを通りたい (同じマスを 2 回以上通っても 1 回とカウント)。 T 回以内の移動で通れるお菓子マスの個数の最大値を求めよ。

制約

  •  1 \le H, W \le 300
  •  1 \le T \le 2 \times 10^{6}

考えたこと

さて、「お菓子マスが 18 個以下」という制約がいかにもヒントである!!! この制約は「bit DP してください」と我々に訴えかけている。

しかも、実行制約時間が 5 sec である。制約だけでなく、実行制約時間も結構ヒントになることがあるので、見るのもいいと思う! 2 sec よりも長めの実行制約時間が設定されているときは、指数時間アルゴリズムだったり、実装が重かったり、重いデータ構造を使ったりすることが想定される傾向がある。この点からも、bit DP で解けそうだという確信が高まる。

スタート・ゴール・お菓子のみでグラフを作る

これは ICPC では本当によく見かける手筋である。グリッドマスは  HW 個あるのだが、本質的には、スタート・ゴール・お菓子のみ考えれば十分だ。

そこで、スタート・ゴール・お菓子のみでグラフを作ることにする。このグラフの頂点数を  N とすると、 N \le 20 となるので、とても頂点数の小さいグラフと言える。

グラフの各辺には、重みとして「その頂点に対応するマス間の最短距離」を乗せることにする。これは、各スタート・ゴール・お菓子マスを始点とした BFS などで求められる。

このグラフを作ってからは、いわゆる TSP (巡回セールスマン問題) のような問題になる。

  • TSP は「すべての頂点を回ってゴールにつくまでの最短時間を求める」
  • 今回は「 T 秒以内にできるだけ多くの頂点を回ってゴールへつくようにする」

ということで微妙に違うようだけど、ほとんど一緒だ。

bit DP

TSP を解く bit DP では、次の配列 dp を考える。今回の問題も同じ配列 dp を考えて解くことができる。


dp[S][v] ← すでに訪れた頂点集合が  S であり、最後に訪れた頂点が  v であるような状態に至るまでの最短時間


この配列 dp が求められたならば、今回の問題の答えは次のように求められる。

dp[S][g] <= T であるような集合  S の要素数の最大値を求める」

ここで、スタートマスを表す頂点番号を  s とし、ゴールマスを表す頂点番号を  g とする。なお、TSP そのものであれば、dp[(1 << N) - 1][g] を答えることになる。

これ以降、DP の遷移式を具体的に考えていく。集合  Sビットで記述することにする。

DP の初期条件

まず、初期条件を考えよう。スタート頂点  s のみを訪れている状態の集合は 1 << s と書けるので、初期条件は次のように書ける。

  • dp[(1 << s)][s] = 0

DP の遷移式

次に DP の遷移式を考える。すでに訪れた頂点集合が  S であり、最後に訪れた頂点が  v であるような状態 (その状態にするまでの最短時間は dp[S][v]) に対して、次に頂点  v_2 に行こうとする場面を考える。

このとき、頂点  v_2 へ移動すると

  • 頂点集合は  S から、 S \cup { v_{2}} へと変更される
  • 最後に訪れた頂点は  v_2 となる。

なお、集合  S を表すビットを bit とすると、新たな頂点集合を表すビットは bit | (1 << v2) と書ける。よって、遷移式は次のように書ける。ここで、頂点  v, v_{2} 間の距離を G[v][v2] と表すことにする。また、ここでは「配る DP」の形で書いている。

dp[bit | (1 << v2)][v2] = min(dp[bit | (1 << v2)][v2], dp[bit][v] + G[v][v2]);

以上の DP を実装すれば AC になる。最後に計算量を見積もる ( N はスタートマス・ゴールマス・お菓子マスの個数)。

  • グリッドから、スタート・ゴール・お菓子のみのグラフを作る:計算量は  O(NHW)
  • bit DP をして解く:計算量は  O(2^{N} N^{2} )

よって、全体の計算量は  O(NHW + 2^{N} N^{2}) となる。

コード

#include <bits/stdc++.h>
using namespace std;
using pint = pair<int,int>;
const vector<int> dx = {1, 0, -1, 0};
const vector<int> dy = {0, 1, 0, -1};

// 前処理:マス (si, sj) をスタートとして BFS
vector<vector<int>> bfs(const vector<string> &A, int sx, int sy) {
    int H = A.size(), W = A[0].size();
    vector<vector<int>> dist(H, vector<int>(W, -1));
    queue<pint> que;
    
    dist[sx][sy] = 0;
    que.push({sx, sy});
    while (!que.empty()) {
        auto [x, y] = que.front();
        que.pop();
        for (int dir = 0; dir < 4; ++dir) {
            int x2 = x + dx[dir], y2 = y + dy[dir];
            if (x2 < 0 || x2 >= H || y2 < 0 || y2 >= W) continue;
            if (A[x2][y2] == '#') continue;
            if (dist[x2][y2] == -1) {
                dist[x2][y2] = dist[x][y] + 1;
                que.push({x2, y2});
            }
        }
    }
    return dist;
}

const long long INF = 1LL<<60;
int main() {
    long long H, W, T;
    cin >> H >> W >> T;
    vector<string> A(H);
    for (int i = 0; i < H; ++i) cin >> A[i];
    
    // S, G, お菓子マスを抽出して、距離グラフを作る
    vector<pint> nodes;
    int start = -1, goal = -1;
    for (int i = 0; i < H; ++i) {
        for (int j = 0; j < W; ++j) {
            if (A[i][j] == 'S') start = nodes.size(), nodes.emplace_back(i, j);
            else if (A[i][j] == 'G') goal = nodes.size(), nodes.emplace_back(i, j);
            else if (A[i][j] == 'o') nodes.emplace_back(i, j);
        }
    }
    int N = nodes.size();
    vector<vector<long long>> G(N, vector<long long>(N, INF));
    for (int i = 0; i < N; ++i) {
        auto [x, y] = nodes[i];
        const auto &dist = bfs(A, x, y);
        for (int j = 0; j < N; ++j) {
            auto [x2, y2] = nodes[j];
            if (dist[x2][y2] != -1) G[i][j] = dist[x2][y2];
        }
    }
    
    // bit DP: dp[S][v] := 頂点集合 S をすべて通って、最後にいる頂点が v であるときの、最小コスト
    vector<vector<long long>> dp(1<<N, vector<long long>(N, INF));
    dp[1<<start][start] = 0;
    for (int bit = 0; bit < (1<<N); ++bit) {
        for (int v = 0; v < N; ++v) {
            for (int v2 = 0; v2 < N; ++v2) {
                int nbit = bit | (1<<v2);
                dp[nbit][v2] = min(dp[nbit][v2], dp[bit][v] + G[v][v2]);
            }
        }
    }
    
    // dp[bit][goal] <= T となる bit について、最多お菓子数を求める
    int res = -1;
    for (int bit = 0; bit < (1<<N); ++bit) {
        if (dp[bit][goal] <= T) res = max(res, __builtin_popcount(bit) - 2);
    }
    cout << res << endl;
}