けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder AGC 033 E - Go around a Circle (赤色, 1500 点)

本番これを間に合わせられなかったのが悔しい

問題へのリンク

問題概要

円周を  N 等分して、それぞれの弧を赤か青に塗る方法のうち、

  • 'R' と 'B' のみからなる長さ  M の文字列  S が与えられて
  • 円周上のどの端点から出発しても弧を順番に左右どちらかに  M 回たどっていくことで  S で表される色列を実現できる

ようなものを数え上げよ。ただし回転して一致するものは別に数えることにする。

制約

  •  2 \le N \le 10^{5}
  •  1 \le M \le 10^{5}

考えたこと

問題の条件がとてもわけがわからないので、少しでもよい特徴づけを探すといい感じになるタイプの問題に見える。こういうのは必要条件をひたすら列挙したら、やがて十分条件に辿り着けるタイプの問題。

とりあえず S は適切な変換を施すことで、B でスタートするとしてよい。このときまずは

  • R が二箇所連続することはない

ということが言える。もし二箇所連続していたら、その間でスタートしたときに B で始められない。よって色塗り方法は B の個数

  • (a1 個、a2 個、...) ((a1 + 1) + (a2 + 1) + ... = N)

みたいな特徴づけができることになる。この a1, a2, ... の満たすべき条件を探る。まず ai は奇数でなければならない。なぜなら、S が最初 B が b 個連続する部分を処理しようと思ったとき、

  • ai が偶数だと、ai の途中部分からスタートするときに左右の B の個数が (偶数, 偶数) と (奇数, 奇数) の両方の場合の出現に対応しないといけなくて、b が偶数でも奇数でも詰んでしまうスタート位置がある

ということが言える。また、

  • b が奇数のときは、ai は b 以下の奇数
  • b が偶数のときは、ai は b+1 以下の奇数

が最初の B b 個に対処できるための必要十分であることも言える。さらに S が

  • B b個、R q個、B r個、R s個、...

となっているとき、途中が「B b個」みたいになっている部分について、円弧の任意の B 連続箇所に突入して b 個連続させてから R に入るようにしなければならないことは言える (スタート位置を任意に調節すればその状態に追い込める)。で、それができる条件は

  • b が偶数のときはなんでもよい
  • b が奇数のときは ai <= b (ai は奇数) が必要十分

であることも言える。以上をまとめるとようするに、円弧を x1 = a1 + 1, x2 = a2 + 1, ... と分ける方法のうち

  • x1, x2, ... は偶数
  • xi <= min(2b, S における B の連続数が奇数である部分の最小値)

を満たすものを数え上げる問題ということになる。ただ例外はあって、S が全部 B の場合は別途考える (この場合は R が連続しないことのみが必要十分条件)。

数え上げ (S が全部 B の場合以外)

N が奇数だったら最初からダメ。
N が偶数のとき、 N = N/2 として適切に問題を変換すると (x1, x2, ..., xK) を順に数珠状に並べる方法のうち、ある整数 v が存在して

  • x1 + x2 + ... + xK = N
  • 1 <= xi <= v

を満たすものを数える問題ということになる。これはある点を固定して、その左右に初めて R マスが登場する位置で場合分けすると、左右のその長さを r とすると数珠でなく数直線に関する問題になる。

  • dp[ i ] := 1 以上 v 以下の整数で i を作る方法の数

を求めておけば集計できる。

#include <iostream>
#include <vector>
#include <string>
using namespace std;

#define COUT(x) cout << #x << " = " << (x) << " (L" << __LINE__ << ")" << endl
template<class T1, class T2> ostream& operator << (ostream &s, pair<T1,T2> P)
{ return s << '<' << P.first << ", " << P.second << '>'; }
template<class T> ostream& operator << (ostream &s, vector<T> P)
{ for (int i = 0; i < P.size(); ++i) { if (i > 0) { s << " "; } s << P[i]; } return s; }


template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) v += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};



// 二項係数ライブラリ
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};



const int MAX = 501010;
const int MOD = 1000000007;
using mint = Fp<MOD>;

int N, M;
string S;

mint solve() {
    BiCoef<mint> bc(MAX);
    if (S[0] == 'R') {
        for (auto &c : S) {
            if (c == 'R') c = 'B';
            else c = 'R';
        }
    }
    vector<int> nums;
    for (int i = 0; i < M;) {
        int j = i+1;
        while (j < M && S[j] == S[i]) ++j;
        nums.push_back(j-i);
        i = j;
    }
    if (nums.size() > 1 && N % 2 == 1) return 0;

    int minv = 1, maxv = N;
    if (nums.size() == 1) minv = 2, maxv = N;
    else {
        N /= 2;
        maxv = nums[0] / 2 + 1;
        for (int i = 0; i+1 < nums.size(); i += 2) {
            if (nums[i] & 1) maxv = min(maxv, (nums[i] + 1) / 2);
        }
    }
    
    vector<mint> dp(N+1, 0), sdp(N+2, 0);
    dp[0] = 1, sdp[1] = 1;
    for (int i = 1; i <= N; ++i) {
        dp[i] = sdp[max(0, i + 1 - minv)] - sdp[max(0, i - maxv)];
        sdp[i+1] = sdp[i] + dp[i];
    }
    mint res = 0;
    if (nums.size() == 1) {
        for (int r = minv; r <= min(maxv, N); ++r) res += dp[N-r] * r;
        res += 1; // all B
    }
    else {
        for (int r = minv; r <= min(maxv, N); ++r) res += dp[N-r] * r * 2;
    }
    return res;
}

int main() {     
    while (cin >> N >> M >> S) {
        cout << solve() << endl;
    }
}