けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AOJ 2630 Dictionary (JAG 夏合宿 2014 day2-B)

最初は「え!? i 行目の文字列情報がないと i+1 行目以降のことを考えられなくない?」となってしまって、途方に暮れてしまった

問題へのリンク

問題概要

 N 個の文字列  s_{1}, \dots, s_{N} があって、それぞれいくつかの文字は '?' で隠されている (したのは  N = 2 の例)

'?' 全体をアルファベット英小文字で埋める方法のうち、辞書順で  s_{1} \lt \dots \lt s_{N} を満たすものが何通りあるか、1000000007 で割ったあまりを求めよ。

?sum??mer
c??a??mp

制約

  •  1 \le N \le 50
  • 各文字列の長さ  \le 20

考えたこと

一見すると、一行ずつ順番に見ていくときに「前回の文字列の完全な形がないと次の文字列が何通りかとか考えられない...」と思ってしまって途方に暮れてしまった。

こういう「全部の情報がないと DP 遷移作れなさそう...」という状態から、「実はうまく状態をまとめれば、ある種の個数だけわかっていればよい」という風にもっていくのは良く見るし、ある程度パターン化できたらよさそう...

さて、2 つの文字列  S, T が辞書順であるとは、

  • S[0] < T[0] である
  • S[0] = T[0] であって、S[1:] < T[1:] である

という 2 つの場合に分けられることに注意する。後者はいかにも「0 文字目以降に関する問題」を「1 文字目以降に関する問題」へと帰着できそうな香りをもっている。そこで、こんな DP が立ちそう。

  • dp[ i ][ j ][ k ][ c ] := N 個の文字列のうち区間 [i, j) の部分について、k 文字目以降に関しては、先頭の文字を 'a' + c 以上にするようなものの場合の数

こうすると、区間 [l, r) のうちのどこまでが k 文字目が 'a'+c になっているかで場合分けして、


各 mid = i, i+1, ... j に対して、
S[ mid-1 ][ k ] = 'a'+c で、S[ mid ][ k ] > 'a'+c のとき

dp[ i ][ mid ][ k+1 ][ 0 ] × dp[ mid ][ j ][ k ][ c + 1 ] 通り


という風になる。あとは初期化に注意しつつ頑張る。実装上の注意点として、初期条件がとても扱いづらいので、メモ化再帰に押し付けるとよさそうな気がする。あと、それぞれの文字列の長さが異なるのが厄介なので末尾に 'A' ('Z' < 'a') を足して揃えることにした。

計算量は  O(N^{3}|S|C) C はアルファベットの個数。

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

using mint = Fp<1000000007>;

int N;
vector<string> S;

string moji;
bool seen[55][55][30][30];
mint dp[55][55][30][30];
mint beki[50];

mint rec(int i, int j, int k, int c) {
    if (i == j) return 1;
    if (c >= moji.size()) return 0;

    // 複数行あって k が超えていたらダメと考える (狭義単調増加にできないため)
    if (k >= S[i].size()) {
        if (j - i == 1) return 1;
        else return 0;
    }
    
    // 1 行の場合
    if (j - i == 1) {
        int con = 0;
        for (int p = k+1; p < S[i].size(); ++p) if (S[i][p] == '?') ++con;
        if (S[i][k] == '?') return beki[con] * min(26, 27 - c);
        else if (S[i][k] < moji[c]) return 0;
        else return beki[con];
    }

    // メモを check
    if (seen[i][j][k][c]) return dp[i][j][k][c];

    // 場合分け
    mint res = 0;
    int mac = 0;
    for (int mid = i; mid <= j; ++mid) {
        res += rec(i, mid, k+1, 0) * rec(mid, j, k, c+1);
        if (mid == j) break;
        if (S[mid][k] != '?' && S[mid][k] != moji[c]) break;
        if (S[mid][k] == '?' && moji[c] == 'A') break;
    }
    seen[i][j][k][c] = true;
    return dp[i][j][k][c] = res;
}

int main() {
    cin >> N;
    S.resize(N);
    int ms = 0;
    for (int i = 0; i < N; ++i) cin >> S[i], ms = max(ms, (int)S[i].size());
    for (int i = 0; i < N; ++i) while (S[i].size() < ms) S[i] += 'A';
    moji = 'A';
    for (int i = 0; i < 26; ++i) moji += (char)('a' + i);
    
    memset(seen, 0, sizeof(seen));
    beki[0] = 1;
    for (int i = 0; i+1 < 50; ++i) beki[i+1] = beki[i] * 26;
    cout << rec(0, N, 0, 0) << endl;
}