けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder AGC 039 C - Division by Two with Something (黄色, 800 点)

この回の前の回の LCMs といい、約数系包除がこの時期流行ってたのかな。

問題概要

整数  N, X が与えられる。
 0 以上  X 以下のすべての整数  k に対し、 k に以下の操作を繰り返すことによって次に  k に戻るまでの操作回数 (戻らない場合 0) を足し合わせた値を 998244353 で割ったあまりを求めyお。

  • 現在の整数が奇数なら、 1 を引いて  2 で割る
  • そうでなければ、 2 で割って  2^{N−1} を足す

制約

  •  1 \le N \le 2 \times 10^{5}

考えたこと

まずは小さい  N であれこれ試すことにした。 N = 3 のときは

000 → 100 → 110 → 111 → 011 → 001 → (000) 
010 → 101 → (010)

という 2 系統に分解できることがわかる。つまり、回数 6 が 6 個と、回数 2 が 2 個になっている。 N = 4 のときは

0000 → 1000 → 1100 → 1110 → 1111 → 0111 → 0011 → 0001 → (0000)
0010 → 1001 → 0100 → 1010 → 1101 → 0110 → 1011 → 0101 → (0010)

という 2 系統に分解できる。それぞれともに周期 8 のサイクルとなっている。 N = 5 のときもやってみると、

  • 00000 を含む周期 10 のサイクル
  • 00010 を含む周期 10 のサイクル
  • 00100 を含む周期 10 のサイクル
  • 01010 を含む周期 2 のサイクル

の 4 系統に分解できることがわかる。この時点でとりあえず、どんな整数  k についても、周期が  2N の約数になるようなサイクルになることを確信した。もう少し詳しく調べることにした。

操作の言い換え

操作は次のようになっている。たとえば  N = 6 k = 110101 としたとき、 k の左側にビット反転したものを付け加えると

001010 110101

という風になる。このとき、操作は 1 個分左にスライドしたものに移る。たとえば

110101 → 011010 → 101101 → ...

という具合だ。このようになる理由は簡単。操作をよく考えると、

  • abcd0 → 1abcd
  • abcd1 → 0abcd

というものになっていることからわかる。よって、整数  k の周期は次のように求められる。


 k = l \times m であるとして、 m を奇数とする

整数  k を長さ  l ずつ  m 分割したとき、交互にビット反転したものとなっているとき、整数  k は周期  2l をもつ


たとえば  N = 6 k = 011001 とすると、(01)(10)(01) が交互にビット反転しているので周期 2 だとわかる。

これを満たすような最小の整数  l が実際の周期となる。

約数系包除へ

以上から、 N の各約数  l (ただし  \frac{N}{l} が奇数) に対して、

  •  f(l) = 周期  l をもつような整数  k であって、 X 以下であるようなものの個数

が求められばよいことになった。実際はたとえば周期 5 をもつような整数  k は 5 の倍数の周期も持つことを考慮しないといけない。単純にやるとダブルカウントしてしまうことに注意。このような状況を扱うためには、約数系包除 (約数での高速メビウス変換) が使えるものと相場は決まっている。われわれは  f(l) を求めることに集中すれば十分。

 f(l) の値は、 X の最初の  l 桁分の値を  A としたとき、 A A+1 のどちらかになる。後者になるのは、 A とそのビット反転したものを交互に並べたものが  X 以下となる場合だ。その判定は  O(N) でできる。 f(l) の計算をするべき  l の個数が約数個 (大雑把に見積もっても  O(\sqrt{N}) 個) しかないので、全体の計算量は  O(N \sqrt{N}) でできる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

void mebius(vector<mint> &v) {
    int N = v.size() - 1;
    vector<bool> isprime(N+1, true);
    for (int i = 2; i * i <= N; ++i) {
        if (!isprime[i]) continue;
        for (int j = i*2; j <= N; j += i) isprime[j] = false;
    }
    for (int p = 2; p <= N; ++p) {
        if (!isprime[p]) continue;

        for (long long i = N/p * p; i >= p; i -= p) {
            v[i] -= v[i/p];
        }
    }
}

vector<long long> calc_divisor(long long n) {
    vector<long long> res;
    for (long long i = 1LL; i*i <= n; ++i) {
        if (n % i == 0) {
            res.push_back(i);
            long long j = n / i;
            if (j != i) res.push_back(j);
        }
    }
    sort(res.begin(), res.end());
    return res;
}

int main() {
    int N;
    string X;
    cin >> N >> X;
    const auto &div = calc_divisor(N);

    vector<mint> v(N+1, 0);
    for (auto l : div) {
        if (N / l % 2 == 0) continue;
        mint num = 0;
        for (int i = 0; i < l; ++i) num = num * 2 + (int)(X[i]-'0');

        // X.substr(l) で始まるやつが X 以下かどうか
        int rev = 0;
        string Y = "", str = X.substr(0, l), rstr = "";
        for (int i = 0; i < str.size(); ++i) {
            if (str[i] == '0') rstr += '1';
            else rstr += '0';
        }
        while (Y.size() < X.size()) {
            if (rev) Y += rstr;
            else Y += str;
            rev = 1 - rev;
        }
        if (Y <= X) num += 1;
        v[l] = num;
    }

    // 高速メビウス変換
    mebius(v);

    // 集計
    mint res = 0;
    for (auto l : div) {
        if (N / l % 2 == 0) continue;
        res += v[l] * l * 2;
    }
    cout << res << endl;
}