けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 304 F - Shift Table (青色, 525 点)

メビウス関数を用いた約数系包除原理を使いこなそう!

問題概要 (表現改)

文字 '.' と '#' のみからなる長さ  N の文字列  S が与えられる。今、文字 '#' をいくつか '.' に書き換えることによって、文字列  S が周期的文字列となるようにしたい。

なお、 S が周期的文字列であるとは、 N より真に小さい  N の約数  M が存在して、 S_{i+M} = S_{i} ( 0 \le i \lt N-M) を満たすことをいうものとする。

書き換えて得られる周期的文字列の個数を 998244353 で割ったあまりを求めよ。

制約

  •  2 \le N \le 10^{5}

考えたこと:最小周期を考える

たとえば次の文字列を見てみよう。

#.#.#.#.#.#.#.#.

この文字列は周期 2 だが、周期 4 であるとも言えるし、周期 6 であるとも言える。つまり、同じ文字列であっても、いくつかの周期が考えられることが厄介だ。

そこで、理想的には、最小周期で場合分けして数え上げることにしたい。最小周期が  d であるような文字列の個数を  f(d) と表すと、答えは  \displaystyle \sum_{d | N, d \neq N} f(d) だ。

最小周期が難しいので、先に周期を考える

しかし、ここで壁にぶつかるのだ。最小周期がちょうど  d であるような文字列を数え上げるのは簡単ではない。最小周期が 6 であるような文字列の個数を考えるためには、上の "#.#.#.#.#.#.#.#." のような文字列は除外しなければならない。この文字列は確かに周期 6 をもつが、最小周期は 2 なのだ。

そこで方針を変えよう。最小周期が  d であるような文字列を数え上げるのではなく、やっぱり、先に周期が  d であるような文字列の個数を数えることにするのだ。これは比較的容易に求められる。

周期が  d であるような文字列の個数を  F(d) と表すことにしよう。 F(d) は次のように求められる。


 S_{i} = S_{i+d} = S_{i+2d} = \dots = '#'

を満たすような  i ( i = 0, 1, \dots, d-1) の個数を  m としたとき、

 F(d) = 2^{m}

である


本当に求めたいのは  f(d) の方だが、それは難しいので、代わりに  F(d) の方を求めたというわけだ。

メビウスの反転公式と、約数系包除原理

最後に、(約数系) 包除原理を活用して、 F(d) (周期  d をもつ文字列の個数) から  f(d) (最小周期が  d である文字列の個数) を求める方法を考えよう。

たとえば、 N = 12 としてみよう。12 の約数として  d = 6 がある。周期が  d = 6 であるとき、その最小周期としては  1, 2, 3, 6 の 4 通りがありうる。よって、次の式が成り立つ。

 F(6) = f(1) + f(2) + f(3) + f(6)

同様に、次のように立式できる。

  •  F(1) = f(1)
  •  F(2) = f(1) + f(2)
  •  F(3) = f(1) + f(3)
  •  F(4) = f(1) + f(2) + f(4)
  •  F(6) = f(1) + f(2) + f(3) + f(6)
  •  F(12) = f(1) + f(2) + f(3) + f(4) + f(6) + f(12)

これらは  f(1), f(2), f(3), f(4), f(6), f(12) についての連立方程式ともみなせる。実は、次の定理が成り立つ。メビウスの反転公式などと呼ばれている。ここで、 \mu(n)メビウス関数である。


【メビウスの反転公式】

 F(n) = \displaystyle \sum_{d|n} f(d) のとき、

 f(n) = \displaystyle \sum_{d|n} \mu(\frac{n}{d}) F(d)


この反転公式を活用して  f(n) の値を求める処理は、約数系包除原理と呼ばれる方法そのものでもある。この反転した式に  n = 1, 2, 3, 4, 6, 12 を適用した式を具体的に書き下してみると、「確かに!」になると思う。

  •  f(1) = F(1)
  •  f(2) = F(2) - F(1)
  •  f(3) = F(3) - F(1)
  •  f(4) = F(4) - F(2)
  •  f(6) = F(6) - F(2) - F(3) + F(1)
  •  f(12) = F(12) - F(6) - F(4) + F(2)

ここで、私たちが求めたい値は  \displaystyle \sum_{d | N, d \neq N} f(d) であることを思い出そう。

たとえば  N = 12 の場合を考えると、求めたい値は  f(1) + f(2) + f(3) + f(4) + f(6) である。ここで、

 F(12) = f(1) + f(2) + f(3) + f(4) + f(6) + f(12)

であることを考慮すると、求めたい値は  F(12) - f(12) に一致する。よって、メビウスの反転公式を適用することで、求めたい値は

 F(12) - f(12) =  F(12) - (F(12) - F(6) - F(4) + F(2)) = F(6) + F(4) - F(2)

と表せる。 F(2), F(4), F(6) の値は容易に求められる。

一般の  N についても同様に、求めたい値は

 \displaystyle \sum_{d | N, d \neq N} f(d)
 = \displaystyle F(N) - f(N)
 = \displaystyle F(N) - \displaystyle \sum_{d|N} \mu(\frac{N}{d}) F(d) (メビウスの反転公式より)
 = \displaystyle -\displaystyle \sum_{d|N, d \neq N} \mu(\frac{N}{d}) F(d)

によって計算できる。計算量は、 N の約数の個数を  D として、 O(N(D + \log \log N)) となる。

補足 1:メビウス関数値の列挙

エラトステネスの篩を用いると、メビウス関数値  \mu(1), \mu(2), \mu(3), \dots, \mu(N) の値を列挙するのを  O(N \log \log N) の計算量で効率よくできます。

詳しいことは次の記事に書きました。この記事では、メビウスの反転公式が「累積和の逆演算」を意味することなども書いていますので、気になる方はぜひ読んでみてください。

qiita.com

補足 2:メビウスの反転公式の証明

次の記事に証明があります。

manabitimes.jp

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    // inner value
    long long val;
    
    // constructor
    constexpr Fp() noexcept : val(0) { }
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr long long get() const noexcept { return val; }
    constexpr int get_mod() const noexcept { return MOD; }
    
    // arithmetic operators
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp &r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp &r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp &r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp &r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp &r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp &r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp &r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp &r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp pow(long long n) const noexcept {
        Fp res(1), mul(*this);
        while (n > 0) {
            if (n & 1) res *= mul;
            mul *= mul;
            n >>= 1;
        }
        return res;
    }
    constexpr Fp inv() const noexcept {
        Fp res(1), div(*this);
        return res / div;
    }

    // other operators
    constexpr bool operator == (const Fp &r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp &r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD> &x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD> &x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &r, long long n) noexcept {
        return r.pow(n);
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD> &r) noexcept {
        return r.inv();
    }
};

// エラトステネスの篩を用いた、メビウス関数値列挙、高速素因数分解、約数列挙
struct Eratos {
    vector<bool> isprime;
    vector<int> mebius;
    vector<int> min_factor;
 
    Eratos(int MAX) : isprime(MAX+1, true),
                      mebius(MAX+1, 1),
                      min_factor(MAX+1, -1) {
        isprime[0] = isprime[1] = false;
        min_factor[0] = 0, min_factor[1] = 1;
        for (int i = 2; i <= MAX; ++i) {
            if (!isprime[i]) continue;
            mebius[i] = -1;
            min_factor[i] = i;
            for (int j = i*2; j <= MAX; j += i) {
                isprime[j] = false;
                if ((j / i) % i == 0) mebius[j] = 0;
                else mebius[j] = -mebius[j];
                if (min_factor[j] == -1) min_factor[j] = i;
            }
        }
    }
 
    // prime factorization
    vector<pair<int,int>> prime_factors(int n) {
        vector<pair<int,int> > res;
        while (n != 1) {
            int prime = min_factor[n];
            int exp = 0;
            while (min_factor[n] == prime) {
                ++exp;
                n /= prime;
            }
            res.push_back(make_pair(prime, exp));
        }
        return res;
    }
 
    // enumerate divisors
    vector<int> divisors(int n) {
        vector<int> res({1});
        auto pf = prime_factors(n);
        for (auto p : pf) {
            int n = (int)res.size();
            for (int i = 0; i < n; ++i) {
                int v = 1;
                for (int j = 0; j < p.second; ++j) {
                    v *= p.first;
                    res.push_back(res[i] * v);
                }
            }
        }
        return res;
    }
};

int main() {
    const int MOD = 998244353;
    using mint = Fp<MOD>;
    
    // 入力
    int N;
    string S;
    cin >> N >> S;
    Eratos er(N + 10);
    
    // 約数系包除原理
    mint res = 0;
    const auto &div = er.divisors(N);
    for (auto d : div) {
        if (d == N) continue;

        // '#' か '.' かを選べるマスの個数 num を求める
        vector<bool> can(d, true);
        for (int i = 0; i < N; ++i) {
            if (S[i] == '.') can[i % d] = false;
        }
        long long num = 0;
        for (int i = 0; i < d; ++i) if (can[i]) ++num;
        
        // 包除原理
        res -= mint(2).pow(num) * er.mebius[N / d];
    }
    cout << res << endl;
}