けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 109 E - 1D Reversi Builder (橙色, 700 点)

これを本番間に合わせられたなかったのは辛かった...
あと、数え上げパートがあんなにスマートにはできなかった。無限に  O(N^{2}) から落とせなかった...

問題概要

黒石さんと白石さんは、一列に並んだ  N 個のマスからなる盤面を使って遊んでいます。 マスにはそれぞれ 1 から  N の整数が順番に振られていて、マス  s に印がつけられています。

まず、黒石さんは、各マスについて独立に、黒か白を等確率で選んで塗ります。その後、マス  s にマスの色と同じ色の石を置きます。

黒石さんと白石さんは、この盤面と無限個の黒い石と白い石を使ったゲームをします。このゲームでは、黒石さんから始めて、黒石さんと白石さんが交互に次の手順で石を置いていきます。

  • 石が置かれているマスと隣接している空きマスをひとつ選ぶ。マス  i を選んだとする
  • マス  i に、マスと同じ色の石をおく
  • 置いた石と同じ色の石がマス  i 以外に置かれているとき、そのうちマス  i に最も近い石と、マス  i の間にあるすべての石の色をマス  i の色に変更する

空きマスが存在しなくなったときにゲームが終了します。

黒石さんはゲーム終了時の黒い石の個数を最大化するために最適な行動をし、白石さんはゲーム終了時の白い石の個数を最大化するために最適な行動をします。

 s = 1, \dots , N のそれぞれの場合について、ゲーム終了時の黒い石の個数の期待値を mod 998244353 で求めてください。

制約

  •  1 \le N \le 2 \times 10^{5}

エスパー

すぬけさんも「こんなに解くのすごいな」と解説放送で言っていた。でもコンテスト後の TL を見ると、エスパーで解き倒した方もたくさんいたっぽい。そういうこともできるようにならないと...となった。まずはエスパーでやってみる。

サンプル 2 ( N = 10) のケースについて、まずは期待値のままだと見えづらいので個数に直してみる ( 2^{10} = 1024 をかけてみる)

0: 5120
1: 5120
2: 5126
3: 5166
4: 5390
5: 5390
6: 5166
7: 5126
8: 5120
9: 5120

なんとなく近い値ばかりなので差分をとってみる!

0: 5120
0->1: 0
1->2: 6
2->3: 40
3->4: 224
4->5: 0
5->6: 998244129
6->7: 998244313
7->8: 998244347
8->9: 0

今回の値は、 s が小さいところと大きいところで左右対称になっている ことが想定されるため、実質的には  s = 4 まで考えれば OK。そうすると、これらの差分は 0, 6, 40, 224 となっていると言える。これらを素因数分解してみよう。

  •  5120 = 2^{10} \times 5 (= 2^{9} \times 10)
  •  6 = 2 \times 3
  •  40 = 2^{3} \times 5
  •  224 = 2^{5} \times 7

となっている。確かに明らかに規則性がありますね!!  N = 19 についても試すと確信に変わる。最初の方の階差数列は  N = 10 の場合と同じだ。また、

  •  4980736 = 2^{18} \times 19

となる。

0: 4980736
0->1: 0
1->2: 6
2->3: 40
3->4: 224
4->5: 1152
5->6: 5632
6->7: 26624
7->8: 122880
8->9: 557056
->10: 997687297
10->11: 998121473
11->12: 998217729
12->13: 998238721
13->14: 998243201
14->15: 998244129
15->16: 998244313
16->17: 998244347
17->18: 0

以上より、

  •  s = 0 のとき、 2^{N-1} \times N
  •  s が小さい範囲での、 s (\ge 1) から  s+1 への差分は  2^{2s-1} \times (2s+1)

とすれば通りそうだと予想できる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int main() {
    int N;
    cin >> N;
    vector<mint> res(N);
    res[0] = modpow(mint(2), N-1) * N;
    for (int s = 0; s+1 < N; ++s) {
        mint diff = 0;
        if (s > 0) diff = modpow(mint(2), s*2-1) * (s*2+1);
        res[s+1] = res[s] + diff;
    }
    for (int i = 0, j = N-1; i < j; ++i, --j) res[j] = res[i];
    for (int i = 0; i < N; ++i) cout << res[i] / modpow(mint(2), N) << endl;
}

正攻法

なによりまずは、ゲームパートを解かないといけない。黒石を o、白石を x (また、黒マスを o、白マスを x で表す) で表すことにする。ここは実験して様子を見ることにした。実験することでわかってきたことは、ゲームにおいて常に

  • oo...o (すべて o)
  • xx...x (すべて x)
  • oo...oxx...x (変わり目が 1 箇所)
  • xx...xoo...o (変わり目が 1 箇所)

のいずれかになる。そして、ゲームの勝ち負けは以下の通りだとわかった。最終状態を (o の個数, x の個数) で表す。

  • 左端も右端も o のときは、最後はすべて o になる
  • 左端も右端も x のときは、最後はすべて x になる
  • 左端から  A 個が o で、右端から  B 個が x であるとする
    • 始点  s A 個の o の中に含まれるときは、 (N-B, B) となる
    • 始点  s B 個の x の中に含まれるときは、 (A, N-A) となる
    • 始点から見て、左端から  A 番目の o への距離 ( a) と、右端から  B 番目の x への距離 ( b) としたとき
      •  a \le b のとき、最終結果は  (A, N - A) となる
      •  a \gt b のとき、最終結果は  (N - B, B) となる
  • 左端から  A 個が x で、右端から  B 個が o であるとする
    • 同様

数え上げ

以上から、少なくとも  O(N^{3}) の計算量をかけてよいならば、数え上げができる。これを  O(N) O(N \log N) まで落とすのは大変だ...僕はここで、指数関数の和を求めたり、いもす法などを駆使したりすることで、なんとか  O(N) にした。あまりにも煩雑なので大変だった。

atcoder.jp

ここでは、公式解説の巧みなアイディアをなんとか吸収していきたい。アイディアは

  • あらゆる初期盤面を考えたときの、最終盤面の o の個数 ( P) の総和 ( \sum P とする, これが求めたい値)
  • あらゆる初期盤面を考えたときの、最終盤面の x の個数 ( Q) の総和 ( \sum Q とする)

とする。求めたいのは  \sum P だが、

 \sum P + \sum Q = N 2^{N}

なので、代わりに  \sum R = \sum P - \sum Q を求めておけば

 \sum P = \frac{\sum R + (\sum P + \sum Q)}{2} = \frac{\sum R}{2} + N 2^{N-1}

と求められる。実は  P = Q となる盤面ペアが大量にあるので、計算が楽になるみたいなのだ。ダメになるペアはこういうやつ

  • ooo x .. ( L 個) .. s .. ( L 個) .. o xxx
  • xxx o .. ( L 個) .. s .. ( L 個) .. x ooo

前者は  (s+L+2, N-s-L-2) となり、後者は  (N-s+L+1, s-L-1) となる。これらのペアにおいては

 R = (s+L+2) + (N-s+L+1) - (N-s-L-2) - (s-L-1) = 2(2L + 3)

となる。各  L = 0, 1, 2, \dots に対して  2^{2L+1} 通りのペアがありうるので、結局、 s \le N/2 においては

 \frac{\sum R}{2}  = \sum_{L = 0}^{s-2} (2L+3) \times 2^{2L+1}

となる ( s \ge N/2 の範囲では  L の範囲の表式が変わる)。よって  s = 1, 2, \dots, N/2-1 において、 s s+1 との間の答えの差分は

 (2s+1) 2^{2s-1}

となる。これは、先ほどのエスパー結果に一致する。