けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 171 F - Strivore (600 点)

実は、元の文字列の形はどうでもよくて、文字列の長さだけが重要という!!!

問題へのリンク

問題概要

長さ  N の英子文字からなる文字列  S が与えられる。これに以下の操作をちょうど  K 回行ってできる文字列が何通り考えられるか、1000000007 で割ったあまりを求めよ。

  • 文字列の好きな箇所に好きな文字 (英子文字) を挿入する

制約

  •  1 \le N, K \le 10^{6}

考えたこと

まずはこの問題、「操作によって出来上がるものが何通りあるか?」という形の問題になっている。この手の問題では、まともに操作を追いかけてはいけない。異なる操作が結局同じ結果になるような重複を取り除く、という方針では頓挫することが多い。この手の問題で有力なのは、


文字列  T が与えられるので、操作によって文字列  T を作ることができるかどうかを判定せよ

という問題を最初に解く


ということだ。まずは判定問題を解くことによって、結果的に作ることが出来得る文字列についてわかりやすい特徴づけを与えることができることが多々ある。

判定問題

判定問題を言い換えると

  • 長さ  N + K の文字列  T が、その部分文字列として  S を含みますか?

ということになる。部分文字列というものを扱う上での考え方は以下の記事に書いた。

qiita.com

具体的には次のような貪欲法で解くことができる。


  • 文字列 T の index を i、文字列 S の index を j として、i = j = 0 に初期化する
  • i = 0, 1, 2, ... と見ていき、T[ i ] = S[ j ] を満たしたら j をインクリメントする
  • i が T の終端に達する前に、j が S の終端に達したら Yes、そうでなければ No

つまり、たとえば S = "abac" という文字列であったとき、T が以下のような形になることが必要十分条件になる。

f:id:drken1215:20200627074347p:plain

ここまでくると、実は S がどんな文字列であるのかは関係なくて、S の長さのみが重要であることがわかる。あとはこのような T を数え上げる問題を考えよう。方針として、大きく 3 通り考えてみることにする。

  1. S の文字を置く場所を考える
  2. DP を高速化する
  3. 形式的冪級数で考える (maspy さんの記事へ)

解法 (1): S の文字を置く場所を考える

これが自然で素朴でわかりやすいと思う。まず T の最後尾のところ (S の最終文字の後ろの任意文字でよいところ) だけ特殊だから、この部分の長さで場合分けする。

T の最後尾のところの長さを  k (= 0, 1, 2, \dots, K) とすると

  • 最後尾のところの場合の数:  26^{k} 通り
  • 最後尾と、S の最後の文字以外の部分について、
    • S の最後の文字以外の入る場所を選ぶ場合の和:  C(N + K - k - 1, N - 1) 通り
    • S の文字の入る場所以外について埋める場合の数:  25^{K - k} 通り

ということで、 C(N + K - k - 1, N - 1) \times 26^{k} \times 25^{K - k} 通りとなる。これを  k = 0, 1, 2, \dots, K について合算すればよい。計算量は  O(N + K) (適切に前処理した場合)。

#include <bits/stdc++.h>
using namespace std;

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

int main() {
    int K;
    string S;
    cin >> K >> S;
    int N = S.size();

    BiCoef<mint> bc(2100000);
    mint res = 0;
    for (int k = 0; k <= K; ++k) {
        res += bc.com(N + K - k - 1, N - 1) 
        * modpow(mint(26), k) * modpow(mint(25), K - k);
    }
    cout << res << endl;
}

解法 (2): DP 高速化

僕は本番こっちだった。ちょっと本筋じゃない感じだけど、こういう迷走したところからも AC に繋げるのは重要かなと思う。僕はこんな感じの DP を考えた

  • dp[ i ][ j ] := T の i 文字目まで考えたときに、S の j 文字目までが登場するような場合の数

そうすると、遷移は次のようになる。

j < N のとき

  • dp[ i + 1 ][ j ] += dp[ i ][ j ] × 25
  • dp[ i + 1 ][ j + 1 ] += dp[ i ][ j ]

j == N のとき

  • dp[ i + 1 ][ j ] += dp[ i ][ j ] × 26

このままだと  O((N + K)K) の計算量を要するので工夫が必要だ。しかし、dp の式にすると、改めて S の長さのみが重要なことが浮き彫りになるので、何かしら綺麗な式で書けそうな気がしてくる。

上の式を眺めると、j == N の場合を無視すると、パスカルの三角形をなす漸化式に重みをつけたような形になっていることがわかる。よって、dp[ i ][ j ] のそれぞれの値は、 C(i, i - j)25^{i - j} という感じの値になりそうなのだ。この辺の感覚は、以下のツイートに綺麗にまとまっている。

概ねこんな感じで良さそうだけど、j == N の場合から DP 遷移が dp[ i + 1 ][ N + 1 ] とかに広がって行かずに、すべて dp[ i + 1 ][ N ] へと集約していくのをみると、次のように考えれば良さそう。


仮想的に、以下の遷移式を考える (j < N かどうかで場合分けしない)

  • dp2[ i + 1 ][ j ] += dp2[ i ][ j ] × 25
  • dp2[ i + 1 ][ j + 1 ] += dp2[ i ][ j ]

このとき、求めたい dp[ N + K ][ N ] の値は、

dp2[ N + K ][ N ] + dp2[ N + K ][ N + 1 ] + ... + dp2[ N + K ][ N + K ]

となる


ここで実験などにより、

  • dp2[ N + K ][ k ] =  C(N + K, N + K - k) \times 25^{N + K - k}

であることがわかるから、結局、 k = 0, 1, \dots, K に対して

  • dp2[ N + K ][ N + K - k ] =  C(N + K, k) \times 25^{k}

を合算すれば OK。

#include <bits/stdc++.h>
using namespace std;

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

int main() {
    int K;
    string S;
    cin >> K >> S;
    int N = S.size();

    BiCoef<mint> bc(2100000);
    mint res = 0;
    for (int k = 0; k <= K; ++k) {
        res += bc.com(N + K, k) * modpow(mint(25), k);
    }
    cout << res << endl;
}

解法 (3):形式的冪級数で議論

maspy さんの記事に詳しい議論があるので、そちらを参考に。

maspypy.com