けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Codeforces 558 DIV2 D. Mysterious Code (R2200)

部分文字列の遷移は愚直に求めた。。。

問題へのリンク

問題概要

長さ  N の文字列 c と、短い文字列 s, t が与えられ、'a'〜'z' と '?' からなっている。'?' を埋める方法のうち、

  • c の連続する部分文字列として s を含む箇所の個数から
  • c の連続する部分文字列として t を含む箇所の個数を引いたもの

の最大値を求めよ。

制約

  •  1 \le N \le 1000
  •  1 \le |s|, |t| \le 50

考えたこと

いわゆる部分文字列の遷移状態を管理するタイプの DP。禁止文字列を含まない文字列の数え上げ問題などでもおなじみ。基本的には

  • dp[ i ][ j ][ k ] := c の最初 i 文字分を見て、c の直近 j 文字が s の左から j 文字と一致して、c の直近 k 文字が t の左から k 文字と一致する (それぞれ最長一致) 場合についての、スコアの最大値

とすればよい。が、この手のやつで注意すべきは、例えば s = "abacabac" だったとき、c の直近 5 文字が "abaca" だったとして、

  • c の次の文字が "b" だったら、状態は 5 から 6 に変化するのは明らかだが
  • c の次の文字が "b" 以外であっても状態が 0 になるとは限らず、'a' だったら 1、"c" だったら 4 になったりする

ということ。この遷移状態を先に求める必要がある。これは trie木などでできるはずだけど、持ってないので愚直にやった。50 文字くらいなら充分間に合う。

#include <iostream>
#include <sstream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <cstring>
#include <string>
#include <vector>
#include <stack>
#include <queue>
#include <deque>
#include <map>
#include <set>
#include <bitset>
#include <numeric>
#include <utility>
#include <iomanip>
#include <algorithm>
#include <functional>
#include <unordered_map>
using namespace std;

template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

template<class T> vector<T> make_vec(size_t a) { return vector<T>(a); }
template<class T, class... Ts> auto make_vec(size_t a, Ts... ts) {
  return vector<decltype(make_vec<T>(ts...))>(a, make_vec<T>(ts...));
}
template<class T, class V>
typename enable_if<is_class<T>::value == 0>::type fill(T &t, const V &v) {
    t = v;
}
template<class T, class V>
typename enable_if<is_class<T>::value != 0>::type fill(T &t, const V &v){
    for (auto &e : t) fill(e, v);
}

#define COUT(x) cout << #x << " = " << (x) << " (L" << __LINE__ << ")" << endl
template<class T1, class T2> ostream& operator << (ostream &s, pair<T1,T2> P)
{ return s << '<' << P.first << ", " << P.second << '>'; }
template<class T> ostream& operator << (ostream &s, vector<T> P)
{ for (int i = 0; i < P.size(); ++i) { if (i > 0) { s << " "; } s << P[i]; } return s; }
template<class T> ostream& operator << (ostream &s, vector<vector<T> > P)
{ for (int i = 0; i < P.size(); ++i) { s << endl << P[i]; } return s << endl; }

#define EACH(i, s) for (__typeof__((s).begin()) i = (s).begin(); i != (s).end(); ++i)
template<class T> ostream& operator << (ostream &s, set<T> P)
{ EACH(it, P) { s << "<" << *it << "> "; } return s << endl; }
template<class T1, class T2> ostream& operator << (ostream &s, map<T1,T2> P)
{ EACH(it, P) { s << "<" << it->first << "->" << it->second << "> "; } return s << endl; }

const int INF = 1<<29;
string c, s, t;

vector<vector<int> > trans(const string &S) {
    int N = (int)S.size();
    vector<vector<int> > res(N+1, vector<int>(26, 0));
    for (int n = 0; n <= N; ++n) {
        string SS = S.substr(0, n);
        for (int i = 0; i < 26; ++i) {
            char c = 'a' + i;
            string SSS = SS + c;
            int tmp = 0;
            for (int j = 0; j <= SSS.size(); ++j) {
                if (S.substr(0, j) == SSS.substr((int)SSS.size()-j, j))
                    chmax(tmp, j);
            }
            res[n][i] = tmp;
        }
    }
    return res;
}

int solve() {
    int N = c.size();
    auto dp = make_vec<int>(N+1, (int)s.size()+1, (int)t.size()+1);
    fill(dp, -INF);
    auto ts = trans(s);
    auto tt = trans(t);
    
    dp[0][0][0] = 0;
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j <= s.size(); ++j) {
            for (int k = 0; k <= t.size(); ++k) {
                if (c[i] != '*') {
                    int ni = (int)(c[i] - 'a');
                    int nj = ts[j][ni];
                    int nk = tt[k][ni];
                    int add = 0;
                    if (nj == s.size()) ++add;
                    if (nk == t.size()) --add;
                    chmax(dp[i+1][nj][nk], dp[i][j][k] + add);
                }
                else {
                    for (int ni = 0; ni < 26; ++ni) {
                        int nj = ts[j][ni];
                        int nk = tt[k][ni];
                        int add = 0;
                        if (nj == s.size()) ++add;
                        if (nk == t.size()) --add;
                        chmax(dp[i+1][nj][nk], dp[i][j][k] + add);
                    }
                }
            }
        }
    }
    
    int res = -INF;
    for (int j = 0; j <= s.size(); ++j) {
        for (int k = 0; k <= t.size(); ++k) {
            chmax(res, dp[N][j][k]);
        }
    }
    return res;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    
    while (cin >> c >> s >> t) {
        cout << solve() << endl;
    }
}