けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 184 E - Third Avenue (水色, 500 点)

おおむね BFS だけど、ちょっとだけ TLE に注意。。。

問題概要

 H \times W のグリッドが与えられる。"." は通路マス、"#" は壁で侵入不能マス、"S" はスタート、"G" はゴールである。さらに英小文字で表された各マス間は 1 手で自由にワープで行き来可能である。

"S" から "G" への最短路長を求めよ。

制約

  •  1 \le H, W \le 2000

考えたこと

まず愚直には、次のようにして作ったグラフ上で BFS する解法が考えられる。

  • 各マスを頂点とするグラフを考える
  • 壁以外のマス同士が隣接しているところには辺を張る
  • 同じ英小文字があるマス同士には辺を張る

しかし極端な話、マスの大半が同じ英小文字になっているようなケースだと、辺の本数が  O(H^{2}W^{2}) くらいになってしまう。これでは通せない。色々工夫が考えられる。

解法 (1):英小文字間のワープは英小文字ごとに 1 回だけにする

実際に BFS をするときには、ある英小文字のマスに初めて到着したときに、それと同じ英小文字のマス全体に 1 ステップで行き渡らせることにする。そしてその後は、その英小文字間のワープは禁止してしまえば OK。

このようにすることで計算量は  O(HW) に落ちる。

#include <bits/stdc++.h>
using namespace std;
using pint = pair<int,int>;
vector<int> dx = {1, 0, -1, 0};
vector<int> dy = {0, 1, 0, -1};

int main() {
    int H, W;
    cin >> H >> W;
    vector<string> fi(H);
    for (int i = 0; i < H; ++i) cin >> fi[i];

    int sx, sy, tx, ty;
    vector<vector<pint>> pl(26);
    for (int i = 0; i < H; ++i) {
        for (int j = 0; j < W; ++j) {
            if (fi[i][j] == 'S') sx = i, sy = j;
            else if (fi[i][j] == 'G') tx = i, ty = j;
            else if (fi[i][j] >= 'a' && fi[i][j] <= 'z') pl[fi[i][j]-'a'].push_back(pint(i,j));
        }
    }

    vector<bool> used(26, false);
    vector<vector<int>> dist(H, vector<int>(W, -1));
    dist[sx][sy] = 0;
    queue<pint> que;
    que.push(pint(sx, sy));
    while (!que.empty()) {
        auto tmp = que.front(); que.pop();
        int x = tmp.first, y = tmp.second;
        if (fi[x][y] >= 'a' && fi[x][y] <= 'z') {
            int c = fi[x][y] - 'a';
            if (!used[c]) {
                for (auto p : pl[c]) {
                    if (dist[p.first][p.second] == -1) {
                        dist[p.first][p.second] = dist[x][y] + 1;
                        que.push(p);
                    }
                }
            }
            used[c] = true;
        }
        for (int dir = 0; dir < 4; ++dir) {
            int nx = x + dx[dir], ny = y + dy[dir];
            if (nx < 0 || nx >= H || ny < 0 || ny >= W || fi[nx][ny] == '#') continue;
            if (dist[nx][ny] == -1) {
                dist[nx][ny] = dist[x][y] + 1;
                que.push(pint(nx, ny));
            }         
        }
    }
    cout << dist[tx][ty] << endl;
}

解法 (2):英小文字ごとにスーパーノード

もともと「密なグラフに対して、スーパーノードを用意することで辺数を減らす」というテクニックはよく見る。今回もそんなふうにできる。


  • 頂点集合を次のようにする
    • 各マス
    • 各英小文字に対応するスーパーノード
  • 辺集合を次のようにする
    • 隣接するマス同士に辺を張る
    • 各英小文字のマスから、対応する文字のスーパーノードへ、長さ 1 の辺を張る
    • そのスーパーノードから、各英小文字のマスへと、長さ 0 の辺を張る

このように作り上げたグラフで "S" マスから "G" マスへの最短路長を求めれば OK。ただし解法 (1) のときとは違って、辺の長さが 0 と 1 の二種類があるので、単純な BFS では解けない。そこで

といった解法が使える!!!ダイクストラ法 ( O(HW(\log H + \log W))) でもいいんだけど、0-1 BFS ( O(HW)) なら計算量が落ちる。

#include <bits/stdc++.h>
using namespace std;
using pint = pair<int,int>;
vector<int> dx = {1, 0, -1, 0};
vector<int> dy = {0, 1, 0, -1};

int main() {
    int H, W;
    cin >> H >> W;
    vector<string> fi(H);
    for (int i = 0; i < H; ++i) cin >> fi[i];
    
    int sx, sy, tx, ty;
    vector<vector<int>> pl(26);
    for (int i = 0; i < H; ++i) {
        for (int j = 0; j < W; ++j) {
            if (fi[i][j] == 'S') sx = i, sy = j;
            else if (fi[i][j] == 'G') tx = i, ty = j;
            else if (fi[i][j] >= 'a' && fi[i][j] <= 'z') 
                pl[fi[i][j]-'a'].push_back(i*W + j);
        }
    }

    vector<int> dist(H*W+26, -1);
    dist[sx * W + sy] = 0;
    deque<int> que;
    que.push_back(sx * W + sy);
    while (!que.empty()) {
        int id = que.front(); que.pop_front();
        if (id >= H * W) {
            int c = id - H*W;
            for (auto nid : pl[c]) {
                if (dist[nid] == -1) {
                    dist[nid] = dist[id];
                    que.push_front(nid);
                }
            }
        }
        else {
            int x = id / W, y = id % W;
            if (fi[x][y] >= 'a' && fi[x][y] <= 'z') {
                int nid = H * W + (int)(fi[x][y] - 'a');
                if (dist[nid] == -1) {
                    dist[nid] = dist[id] + 1;
                    que.push_back(nid);
                }
            }
            for (int dir = 0; dir < 4; ++dir) {
                int nx = x + dx[dir], ny = y + dy[dir];
                if (nx < 0 || nx >= H || ny < 0 || ny >= W || fi[nx][ny] == '#') continue;
                int nid = nx * W + ny;
                if (dist[nid] == -1) {
                    dist[nid] = dist[id] + 1;
                    que.push_back(nid);
                }         
            }
        }
    }
    cout << dist[tx * W + ty] << endl;
}