けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 175 F - Making Palindrome (橙色, 600 点)

こういう重たい実装を確実にこなせるように...なりたい!

問題へのリンク

問題概要

 N 個の文字列  S_{1}, \dots, S_{N} が与えられる。これらを好きな順序で好きな回数だけ concat して回文を作りたい。ただし  i 番目の文字列を使用するコストは 1 回あたり  C_{i} である。

回文を作れるかどうかを判定し、作れるならば最小コストを求めよ。

制約

  •  1 \le N \le 50
  •  1 \le |S_{i}| \le 20

考えたこと

最初なんもわからんかった。ただちょっと気になっていたのは

  • 回文の真ん中にある文字列は切断されているはず
  • その両端で上手に半分全列挙的なことができないか...?

といったことを考えていた。ただ  N \le 50 とかいう制約は半分全列挙するのは現実的じゃないし、そもそも同じ文字列を何度でも使って良いというのでは厳しい。

でも、回文の真ん中の文字列を考えるのは悪くなかったみたいで、ここからもう少し考察を深めればよかった。。。

探索することを考えてみる

この問題で探索的な解法をとろうと思うと


  • まず回文の真ん中をなす文字列の (種類, 切れ目) を全探索する
  • その後
    • 左にはみ出していたら、右側に付け加える文字列 (種類) を全探索する
    • 右にはみ出していたら、左側に付け加える文字列 (種類) を全探索する
  • このような方法を全探索して、「はみ出し」がなくなるパターンを抽出していく

という感じの解法が考えられる。一見すると指数時間なのだが、考慮すべき状態数は以下の通りで、実はかなり少ない!!!

  • 左側のはみ出しに関しては、各文字列の prefix のみ考えれば良い
  • 右側のはみ出しに関しては、各文字列の suffix のみ考えれば良い

ということで、ありうる状態数は  O(NS) とかなのだ (ここでは  S を各文字列の文字数の最大値とした)。これらをノードとしたグラフ上で Dijkstra 法を回せば OK!!!

計算量は、

  • 頂点数  O(NS)
  • 辺数  O(N^{2}S)

というわけで  O(N^{2}S (\log{N} + \log{S})) とかになる。実装上の簡便さのために、各ノードを

  • pair<string,int>:(はみ出した文字列, 左右のどちらか)

で管理して DP 配列を map<pair<string,int>, long long> 型で管理するなどしても、十分間に合った。

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }
const long long INF = 1LL<<60;
using Node = pair<string, int>; // left: 0, right: 1

long long solve() {
    int N; 
    cin >> N;
    vector<string> S(N);
    vector<long long> C(N);
    for (int i = 0; i < N; ++i) cin >> S[i] >> C[i];

    long long res = INF;
    map<Node, long long> dp;
    priority_queue<pair<long long, Node>,
        vector<pair<long long, Node>>, greater<pair<long long, Node>>> que;

    auto func = [&](const string &left, const string &right) {
        int i = (int)left.size()-1, j = 0;
        while (true) {
            if (i == -1 || j == (int)right.size()) break;
            if (left[i] != right[j]) return Node("", -1);
            --i, ++j;
        }
        if (left.size() >= right.size()) {
            string res = left.substr(0, (int)left.size() - (int)right.size());
            return Node(res, 0);
        }
        else {
            string res = right.substr((int)left.size());
            return Node(res, 1);
        }
    };
    auto relax = [&](const string &left, const string &right, long long cost) {
        auto v = func(left, right);
        if (v.second == -1) return;
        if (v.first == "") chmin(res, cost);
        else {
            que.push(make_pair(cost, v));
            if (dp.count(v)) chmin(dp[v], cost);
            else dp[v] = cost;
        }
    };
    
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j <= S[i].size(); ++j) {
            string left = S[i].substr(0, j), right = S[i].substr(j);
            relax(left, right, C[i]);
        }
        for (int j = 0; j < S[i].size(); ++j) {
            string left = S[i].substr(0, j), right = S[i].substr(j+1);
            relax(left, right, C[i]);
        }
    }
    while (!que.empty()) {
        auto tmp = que.top(); que.pop();
        long long dist = tmp.first;
        string cur = tmp.second.first;
        int type = tmp.second.second;
        if (dist > dp[Node(cur, type)]) continue;
        for (int i = 0; i < N; ++i) {
            if (type == 0) relax(cur, S[i], dist + C[i]);
            else relax(S[i], cur, dist + C[i]);
        }
    }
    return res < INF ? res : -1;
}

int main() {
    cout << solve() << endl;
}