けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 108 D - AB (青色, 600 点)

こういうの最初は手で様子を掴むようにしているのだけど、どのタイミングで PC 実験開始しようか悩む。

問題概要

整数  N と 4 つの文字  c_{AA},c_{AB},c_{BA},c_{BB} が与えられる (いずれも "A" または "B")

すぬけ君は文字列  s を持っている。  s ははじめ "AB" である。すぬけ君は以下の 4 種類の操作を任意の順序で 0 回以上行うことができる。

  • 文字列  s 中の連続する 2 文字が "AA" である箇所の間に、文字  c_{AA} を挿入する
  • 文字列  s 中の連続する 2 文字が "AB" である箇所の間に、文字  c_{AB} を挿入する
  • 文字列  s 中の連続する 2 文字が "BA" である箇所の間に、文字  c_{BA} を挿入する
  • 文字列  s 中の連続する 2 文字が "BB" である箇所の間に、文字  c_{BB} を挿入する

 s の長さが  N になるまで操作を行ったあとの  s としてありうる文字列の個数を 1000000007 で割ったあまりを求めよ。

制約

  •  2 \le N \le 1000

考えたこと

4 つの文字のパターンとしては  2^{4} = 16 通りある。とりあえずいろいろ試してみる。

まず、

  | A B |
A | A A |
B | * * |

というパターン (* はなんでもよい) の場合は、"A...AAAAAAAB" にしかならないので 1 通りになることがわかる。同様に

  | A B |
A | * B |
B | * B |

というパターンの場合も、"ABBBBBBB...B" にしかならないので 1 通りになることがわかる。これで 16 通り中 8 通りはわかった。さて、ここから先は実験してみることにした。

実験

 N = 10 として、16 通りのパターンの結果を一通り出力してみた。実験に用いたコードは以下の感じ (modint 省略)。

出力結果は次のようになった。

AAAA: 1
BAAA: 34
ABAA: 128
BBAA: 128
AABA: 1
BABA: 128
ABBA: 34
BBBA: 34
AAAB: 1
BAAB: 34
ABAB: 1
BBAB: 1
AABB: 1
BABB: 128
ABBB: 1
BBBB: 1
mint naive(int N, string pat) {
    set<string> res, nres;
    res.insert("AB");
    for (int i = 0; i < N-2; ++i) {
        nres.clear();

        for (auto str : res) {
            for (int i = 0; i+1 < str.size(); ++i) {
                char c;
                string tmp = str.substr(i, 2);
                if (tmp== "AA") c = pat[0];
                else if (tmp == "AB") c = pat[1];
                else if (tmp == "BA") c = pat[2];
                else c = pat[3];
                string nstr = str.substr(0, i+1) + c + str.substr(i+1);
                nres.insert(nstr);
            }
        }
        swap(res, nres);
    }
    return mint(res.size());
}

int main() {
    // test
    int N = 10;
    for (int bit = 0; bit < 16; ++bit) {
        string pat = "";
        for (int i = 0; i < 4; ++i) {
            if (bit & (1<<i)) pat += "B";
            else pat += "A";
        }
        cout << pat << ": " << naive(N, pat) << endl;
    }
}

結果の解釈

実験結果を見ると、

という 3 通りの値しか出てこないことがわかった。個別に見るとこのようになる理由は納得いくものではあった。

計算量は  O(N) となる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

mint naive(int N, string pat) {
    set<string> res, nres;
    res.insert("AB");
    for (int i = 0; i < N-2; ++i) {
        nres.clear();

        for (auto str : res) {
            for (int i = 0; i+1 < str.size(); ++i) {
                char c;
                string tmp = str.substr(i, 2);
                if (tmp== "AA") c = pat[0];
                else if (tmp == "AB") c = pat[1];
                else if (tmp == "BA") c = pat[2];
                else c = pat[3];
                string nstr = str.substr(0, i+1) + c + str.substr(i+1);
                nres.insert(nstr);
            }
        }
        swap(res, nres);
    }
    return mint(res.size());
}

mint solve(int N, string pat) {
    if (N == 2) return 1;
    vector<mint> two(N+1, 1), fib(N+1, 1);
    for (int n = 1; n <= N; ++n) two[n] = two[n-1] * 2;
    for (int n = 2; n <= N; ++n) fib[n] = fib[n-1] + fib[n-2];
    if (pat[0] == 'A' && pat[1] == 'A') return 1;
    else if (pat[1] == 'B' && pat[3] == 'B') return 1;
    else if (pat == "ABAA" || pat == "BBAA" || pat == "BABA" || pat == "BABB") return two[N-3];
    else return fib[N-2];
}

int main() {
    // test
    /*
    int N = 10;
    for (int bit = 0; bit < 16; ++bit) {
        string pat = "";
        for (int i = 0; i < 4; ++i) {
            if (bit & (1<<i)) pat += "B";
            else pat += "A";
        }
        cout << pat << ": " << naive(N, pat) << endl;
    }
    */

    int N;
    string pat(4, '*');
    cin >> N >> pat[0] >> pat[1] >> pat[2] >> pat[3];
    cout << solve(N, pat) << endl;
}