けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 115 B - Plus Matrix (茶色, 400 点)

「決めてから、整合性を確認する」というタイプの問題の典型例ですね!

問題概要

 N \times N の非負整数を成分とする行列  C が与えられる。

すべての  (i,j) について  C_{i,j} = A_{i} + B_{j} を満たすような非負整数列  A_{1}, \dots, A_{N} B_{1}, \dots, B_{N} の組が存在するか判定し、存在するなら一つ出力せよ。

制約

  •  1 \le N \le 500

考えたこと

条件を満たすような行列  C であれば、

1 7 5
3 9 7

というようになっている。つまり、

  • 2 行目 1 列目の 3 は、1 行目 1 列目の 1 に +2 したもの
  • 2 行目 2 列目の 9 は、1 行目 2 列目の 7 に +2 したもの
  • 2 行目 3 列目の 7 は、1 行目 3 列目の 5 に +2 したもの

というように、2 行目の値は一律に 1 行目に同じ値を足し引きしたものとなっている。つまり、

 A_{2} = A_{1} + 2

という風にするのが適切だということだ。

同じ理屈で、次のようなことが言える。

  •  A_{1}, \dots, A_{N} の各隣接項の差分については、1 列目のみを見れば一意に決まる
  •  B_{1}, \dots, B_{N} の各隣接項の差分については、1 行目のみを見れば一意に決まる

ということがわかる。たとえばサンプル 1 の

4 3 5
2 1 3
3 2 4

というケースであれば、

  •  A_{2} = A_{1} - 2 (4 と 2 を比較)
  •  A_{3} = A_{2} + 1 (2 と 3 を比較)
  •  B_{2} = B_{1} - 1 (4 と 3 を比較)
  •  B_{3} = B_{2} + 2 (3 と 5 を比較)

という関係が成立することが必要だと考えることができるのだ。ただしこの時点では、 A B相対的な値の関係がわかるのみであって、値が完全に決まるわけではないことに注意しよう。

とりあえず、相対的な値の関係がわかるので、最小値をとるところを 0 に設定してあげることにしよう。たとえば上のケースであれば、1 列目が (4, 2, 3) なので

 A = (2, 0, 1)

としてあげるのが良さそうだ。同様に

 B = (1, 0, 2)

としてあげれば OK。これだと行列  C に対して「ちょうど 1 だけ足りない」という感じなので、A か B のどちらかに一律に 1 を足せば OK。B に 1 ずつ足すと

  •  A = (2, 0, 1)
  •  B = (2, 1, 3)

となる。

ダメな場合

逆に、

  •  A = (2, 0, 1)
  •  B = (1, 0, 2)

というように、一行目と一列目のみを見て決定した  A, B に対して、

 C_{i, j} A_{i} + B_{j} との差

が一定でない場合は "No" ということになる。

もう一つ、その差が負の値になる可能性 (=  A B に負の整数が混ざってしまう可能性) があるのではないかと気になってしまうかもしれない。しかしそれはあり得ない。なぜなら、上記のような  A, B の作り方をした場合、 A_{i} + B_{j} の最小値は 0 になるのだ。 C_{i, j} の値はすべて 0 以上であることが保証されているので、 C_{i, j} - (A_{i} + B_{j}) が負になることはあり得ない。

コード

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

int main() {
    int N; cin >> N;
    vector<vector<long long>> C(N, vector<long long>(N));
    for (int i = 0; i < N; ++i) for (int j = 0; j < N; ++j) cin >> C[i][j];
    vector<long long> X(N), Y(N);
        
    auto solve = [&]() -> bool {        
        long long mix = C[0][0], miy = C[0][0];
        for (int i = 0; i < N; ++i) chmin(mix, C[i][0]), chmin(miy, C[0][i]);
        for (int i = 0; i < N; ++i) X[i] = C[i][0] - mix, Y[i] = C[0][i] - miy;

        // 差分を一律に足す
        long long dif = C[0][0] - X[0] - Y[0];
        for (int i = 0; i < N; ++i) X[i] += dif;

        // 整合性を確認する
        for (int i = 0; i < N; ++i) 
            for (int j = 0; j < N; ++j) 
                if (X[i] + Y[j] != C[i][j])
                    return false;
        return true;
    };
    
    if (solve()) {
        cout << "Yes" << endl;
        for (int i = 0; i < N; ++i) cout << X[i] << " "; cout << endl;
        for (int i = 0; i < N; ++i) cout << Y[i] << " "; cout << endl;
    }
    else cout << "No" << endl;
}

AtCoder ARC 115 C - ℕ Coloring (茶色, 500 点)

これ茶色はさすがにびっくりした!

問題概要

整数  N が与えられる。

正の整数からなる数列  A_{1}, \dots, A_{N} であって、

  •  i j の約数であるような任意の  i, j に対して  A_{i} \neq A_{j}

という条件を満たすものを考える。そのような数列のうち、数列に登場する値の最大値が最小となるようなものを一つ求めよ。

考えたこと

たとえば  N = 24 のとき、1, 2, 4, 12, 24 というように、隣接する二要素が約数倍数関係であるようなものをとってくると、どの 2 つも約数倍数であるようになっている。

つまりこの場合、

  •  A_{1} = 1
  •  A_{2} = 2
  •  A_{4} = 3
  •  A_{12} = 4
  •  A_{24} = 5

というように、絶対に「5」までは必要になるのだ。

一般に、 N 以下の整数のうち、「素因数分解したときの指数の和の最大値 + 1」だけの値までは必ず必要になることが示せる。たとえば

 24 = 2^{3} \times 3^{1}

なので、24 を含むなら 3 + 1 + 1 = 5 までは必ず必要というわけだ。

十分性

逆に、 N 以下の整数のうち、「素因数分解したときの指数の和の最大値 + 1」が答えとなるような数列を具体的に作ることもできる。

具体的には、各  n に対して

としてあげれば、条件を満たすことになる!!!

 p q の約数であるとき、 p素因数分解したときの指数の和は、 q素因数分解したときの指数の和よりも真に小さいことからわかる。

コード

#include <bits/stdc++.h>
using namespace std;

int main() {
    auto calc = [&](long long n) -> long long {
        long long res = 1;
        if (n == 1) return res;
        for (long long p = 2; p * p <= n; ++p) {
            while (n % p == 0) {
                ++res;
                n /= p;
            }
        }
        if (n > 1) ++res;
        return res;
    };
    
    int N;
    cin >> N;
    for (long long n = 1; n <= N; ++n) cout << calc(n) << " ";
    cout << endl;
}

AtCoder ARC 115 D - Odd Degree (黄色, 600 点)

なんとか解けた。若干エスパー気味に解いた。

問題概要

頂点数  N、辺数  M の無向単純グラフが与えられる。

 k = 0, 1, \dots, N に対して、この誘導部分グラフ (頂点集合はそのまま、辺集合は部分集合) であって、次数が奇数の頂点が  k 個であるようなものの個数を 999244353 で割ったあまりを求めよ。

制約

  •  N, M \le 5000

考えたこと

まずは色々なグラフで試してみることにした。そうしていくうちに

  • 連結成分ごとに独立して考えて、得られた解を多項式としての掛け算をしたものが答えになる (それはそう)
  • 頂点数  N のパスグラフに対しては  ({}_{N}{\rm C}_{0}, 0, {}_{N}{\rm C}_{2}, 0, {}_{N}{\rm C}_{4}, \dots) が答え

という風になってることがわかってきた。さらにパスグラフでなくても、木であれば答えが一定になるっぽいことがわかって来た。 そしてさらに色んなケースを試すことで、

  • 連結であれば、辺数が 1 増えると、各  k に対する答えが全体的に 2 倍ずつになる

ということが見えて来た。つまり、連結グラフであれば、頂点数と辺数のみに依存して答えが求まるのだ。これを素直に実装すると通った。

背景

グラフの接続行列を用いた  {\rm F}_{2} 線形代数っぽい雰囲気の問題では、「全域木に帰着して考える」という方法が有効打になることが多いみたいだ。

今回の問題も、連結グラフに対して

  • 連結グラフに辺を 1 本追加すると、全体の答えが 2 倍になる
  • 木についての答えが  ({}_{N}{\rm C}_{0}, 0, {}_{N}{\rm C}_{2}, 0, {}_{N}{\rm C}_{4}, \dots) となる

ということを解決していけば OK。

前者 (全域木へと帰着)

こっちは比較的わかりやすい。追加する辺を  e としたとき、「次数が奇数となる頂点集合」を変えないように、

  •  e を含まない誘導部分グラフ
  •  e を含む誘導部分グラフ

との間に一対一対応を作ることができるのだ。具体的には、 e を含むサイクルをとることができる (全域木の基本サイクル) ので、そのサイクル上の各辺について「部分グラフに選ぶ」「部分グラフに選ばない」を反転していけば OK。

よって、連結グラフに辺  e を追加するとき、全体の答えは 2 倍になることがわかった。

後者 (木の場合)

まず奇数次数の頂点はかならず偶数個になることに注意する (有名問題)。そして、次のことが成立する!!!


 k が偶数のとき、 N 個の頂点から  k 個の頂点を選べば、それらの頂点の次数を奇数にして、それ以外の頂点の次数を偶数にするようなものは、ただ一つ存在する。


具体的にはこれは、葉から順に決まって行くイメージなのだ。各頂点の偶奇を決めてあげると、葉に接続している各辺に対して、その辺を部分グラフとして採用すべきかどうかが決まって行く。そして葉から内側への辺に対しても Greedy に採用すべきかどうかが決まって行くのだ。

よって、 ({}_{N}{\rm C}_{0}, 0, {}_{N}{\rm C}_{2}, 0, {}_{N}{\rm C}_{4}, \dots) が答えになる。

補足

maspy さんの話が面白かった

コード

各連結成分に対する解は明示的に得られるので、それらを多項式として掛け算していけば OK。

特に、「二分木のような計算順序」 (この記事を参照) を用いて FFT を活用していくことで、計算量は  O(N (\log N)^{2}) となる。

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator < (const Fp& r) const noexcept {
        return this->val < r.val;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

namespace NTT {
    long long modpow(long long a, long long n, int mod) {
        long long res = 1;
        while (n > 0) {
            if (n & 1) res = res * a % mod;
            a = a * a % mod;
            n >>= 1;
        }
        return res;
    }

    long long modinv(long long a, int mod) {
        long long b = mod, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        u %= mod;
        if (u < 0) u += mod;
        return u;
    }

    int calc_primitive_root(int mod) {
        if (mod == 2) return 1;
        if (mod == 167772161) return 3;
        if (mod == 469762049) return 3;
        if (mod == 754974721) return 11;
        if (mod == 998244353) return 3;
        int divs[20] = {};
        divs[0] = 2;
        int cnt = 1;
        long long x = (mod - 1) / 2;
        while (x % 2 == 0) x /= 2;
        for (long long i = 3; i * i <= x; i += 2) {
            if (x % i == 0) {
                divs[cnt++] = i;
                while (x % i == 0) x /= i;
            }
        }
        if (x > 1) divs[cnt++] = x;
        for (int g = 2;; g++) {
            bool ok = true;
            for (int i = 0; i < cnt; i++) {
                if (modpow(g, (mod - 1) / divs[i], mod) == 1) {
                    ok = false;
                    break;
                }
            }
            if (ok) return g;
        }
    }

    int get_fft_size(int N, int M) {
        int size_a = 1, size_b = 1;
        while (size_a < N) size_a <<= 1;
        while (size_b < M) size_b <<= 1;
        return max(size_a, size_b) << 1;
    }

    // number-theoretic transform
    template<class mint> void trans(vector<mint>& v, bool inv = false) {
        if (v.empty()) return;
        int N = (int)v.size();
        int MOD = v[0].getmod();
        int PR = calc_primitive_root(MOD);
        static bool first = true;
        static vector<long long> vbw(30), vibw(30);
        if (first) {
            first = false;
            for (int k = 0; k < 30; ++k) {
                vbw[k] = modpow(PR, (MOD - 1) >> (k + 1), MOD);
                vibw[k] = modinv(vbw[k], MOD);
            }
        }
        for (int i = 0, j = 1; j < N - 1; j++) {
            for (int k = N >> 1; k > (i ^= k); k >>= 1);
            if (i > j) swap(v[i], v[j]);
        }
        for (int k = 0, t = 2; t <= N; ++k, t <<= 1) {
            long long bw = vbw[k];
            if (inv) bw = vibw[k];
            for (int i = 0; i < N; i += t) {
                mint w = 1;
                for (int j = 0; j < t/2; ++j) {
                    int j1 = i + j, j2 = i + j + t/2;
                    mint c1 = v[j1], c2 = v[j2] * w;
                    v[j1] = c1 + c2;
                    v[j2] = c1 - c2;
                    w *= bw;
                }
            }
        }
        if (inv) {
            long long invN = modinv(N, MOD);
            for (int i = 0; i < N; ++i) v[i] = v[i] * invN;
        }
    }

    // for garner
    static constexpr int MOD0 = 754974721;
    static constexpr int MOD1 = 167772161;
    static constexpr int MOD2 = 469762049;
    using mint0 = Fp<MOD0>;
    using mint1 = Fp<MOD1>;
    using mint2 = Fp<MOD2>;
    static const mint1 imod0 = 95869806; // modinv(MOD0, MOD1);
    static const mint2 imod1 = 104391568; // modinv(MOD1, MOD2);
    static const mint2 imod01 = 187290749; // imod1 / MOD0;

    // small case (T = mint, long long)
    template<class T> vector<T> naive_mul 
    (const vector<T>& A, const vector<T>& B) {
        if (A.empty() || B.empty()) return {};
        int N = (int)A.size(), M = (int)B.size();
        vector<T> res(N + M - 1);
        for (int i = 0; i < N; ++i)
            for (int j = 0; j < M; ++j)
                res[i + j] += A[i] * B[j];
        return res;
    }

    // mint
    template<class mint> vector<mint> mul
    (const vector<mint>& A, const vector<mint>& B) {
        if (A.empty() || B.empty()) return {};
        int N = (int)A.size(), M = (int)B.size();
        if (min(N, M) < 30) return naive_mul(A, B);
        int MOD = A[0].getmod();
        int size_fft = get_fft_size(N, M);
        if (MOD == 998244353) {
            vector<mint> a(size_fft), b(size_fft), c(size_fft);
            for (int i = 0; i < N; ++i) a[i] = A[i];
            for (int i = 0; i < M; ++i) b[i] = B[i];
            trans(a), trans(b);
            vector<mint> res(size_fft);
            for (int i = 0; i < size_fft; ++i) res[i] = a[i] * b[i];
            trans(res, true);
            res.resize(N + M - 1);
            return res;
        }
        vector<mint0> a0(size_fft, 0), b0(size_fft, 0), c0(size_fft, 0);
        vector<mint1> a1(size_fft, 0), b1(size_fft, 0), c1(size_fft, 0);
        vector<mint2> a2(size_fft, 0), b2(size_fft, 0), c2(size_fft, 0);
        for (int i = 0; i < N; ++i)
            a0[i] = A[i].val, a1[i] = A[i].val, a2[i] = A[i].val;
        for (int i = 0; i < M; ++i)
            b0[i] = B[i].val, b1[i] = B[i].val, b2[i] = B[i].val;
        trans(a0), trans(a1), trans(a2), trans(b0), trans(b1), trans(b2);
        for (int i = 0; i < size_fft; ++i) {
            c0[i] = a0[i] * b0[i];
            c1[i] = a1[i] * b1[i];
            c2[i] = a2[i] * b2[i];
        }
        trans(c0, true), trans(c1, true), trans(c2, true);
        static const mint mod0 = MOD0, mod01 = mod0 * MOD1;
        vector<mint> res(N + M - 1);
        for (int i = 0; i < N + M - 1; ++i) {
            int y0 = c0[i].val;
            int y1 = (imod0 * (c1[i] - y0)).val;
            int y2 = (imod01 * (c2[i] - y0) - imod1 * y1).val;
            res[i] = mod01 * y2 + mod0 * y1 + y0;
        }
        return res;
    }

    // long long
    vector<long long> mul_ll
    (const vector<long long>& A, const vector<long long>& B) {
        if (A.empty() || B.empty()) return {};
        int N = (int)A.size(), M = (int)B.size();
        if (min(N, M) < 30) return naive_mul(A, B);
        int size_fft = get_fft_size(N, M);
        vector<mint0> a0(size_fft, 0), b0(size_fft, 0), c0(size_fft, 0);
        vector<mint1> a1(size_fft, 0), b1(size_fft, 0), c1(size_fft, 0);
        vector<mint2> a2(size_fft, 0), b2(size_fft, 0), c2(size_fft, 0);
        for (int i = 0; i < N; ++i)
            a0[i] = A[i], a1[i] = A[i], a2[i] = A[i];
        for (int i = 0; i < M; ++i)
            b0[i] = B[i], b1[i] = B[i], b2[i] = B[i];
        trans(a0), trans(a1), trans(a2), trans(b0), trans(b1), trans(b2);
        for (int i = 0; i < size_fft; ++i) {
            c0[i] = a0[i] * b0[i];
            c1[i] = a1[i] * b1[i];
            c2[i] = a2[i] * b2[i];
        }
        trans(c0, true), trans(c1, true), trans(c2, true);
        static const long long mod0 = MOD0, mod01 = mod0 * MOD1;
        vector<long long> res(N + M - 1);
        for (int i = 0; i < N + M - 1; ++i) {
            int y0 = c0[i].val;
            int y1 = (imod0 * (c1[i] - y0)).val;
            int y2 = (imod01 * (c2[i] - y0) - imod1 * y1).val;
            res[i] = mod01 * y2 + mod0 * y1 + y0;
        }
        return res;
    }
};

// Binomial Coefficient
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

using Graph = vector<vector<int>>;
int V, E;
void dfs(const Graph &G, int v, vector<bool> &seen) {
    seen[v] = true;
    ++V;
    for (auto e: G[v]) {
        ++E;
        if (seen[e]) continue;
        dfs(G, e, seen);
    }
}

int main() {
    int N, M;
    cin >> N >> M;
    BiCoef<mint> bc(N+1);
    Graph G(N);
    for (int i = 0; i < M; ++i) {
        int a, b;
        cin >> a >> b;
        --a, --b;
        G[a].push_back(b);
        G[b].push_back(a);
    }

    // 各連結成分ごとに求める
    priority_queue<pair<int,vector<mint>>, vector<pair<int,vector<mint>>>, greater<pair<int,vector<mint>>>> que;
    vector<bool> seen(N, false);
    for (int v = 0; v < N; ++v) {
        if (seen[v]) continue;
        V = 0, E = 0;
        dfs(G, v, seen);
        E /= 2;
        mint two = 1;
        for (int i = 0; i < E-V+1; ++i) two *= 2;
        vector<mint> f(V+1, 0);
        for (int i = 0; i <= V; ++i) {
            if (i % 2 == 1) continue;
            f[i] = bc.com(V, i) * two;
        }
        que.push({f.size(), f});
    }

    // 二分木のような計算順序で FFT
    while (que.size() >= 2) {
        auto f = que.top().second; que.pop();
        auto g = que.top().second; que.pop();
        auto h = NTT::mul(f, g);
        que.push({h.size(), h});
    }
    auto res = que.top().second;
    for (int i = 0; i <= N; ++i) cout << res[i] << endl;
}

AtCoder ARC 115 E - LEQ and NEQ (黄色, 700 点)

間に合わなかった!!!悔しい!!!

問題概要

長さ  N の数列  A_{1}, \dots, A_{N} が与えられます。以下の条件を満たすような、長さ  N の数列  X_{1}, \dots, X_{N} の個数を 998244353 で割ったあまりを答えよ。

  •  1 \le X_{i} \le A_{i}
  •  X_{i} \neq X_{i+1}

制約

  •  2 \le N \le 5 \times 10^{5}
  •  1 \le A_{i} \le 10^{9}

考えたこと

 X_{i} \neq X_{i+1} という条件は扱いづらいので、包除原理でやると良さそう。

 N-1 個の隙間のうち、いくつかの箇所に「=」を入れると、「=」で繋がれた区間は一つに潰れる感じになる。そしてその区間については「区間内の  A_{i} の最小値」に置き換えてあげる。このように縮退してできる数列に対して、単純に掛け算を取れば OK。

というわけで、基本的には区間ごとに分割していくタイプの DP で扱えそうだ。

  • dp[i] ← 数列の最初の i 個についての場合の数

こうすると、次のような DP によって  O(N^{2}) にはなる。

dp[i] += dp[j] × (A[j:i] の最小値) ×  (-1)^{j-i-1}

とりあえず  (-1)^{j-i-1} という部分がややこしいので、

dp[i] (-1)^{i} × dp[i]

と置き直すことで、次のように変形した。

dp[i] =  - \sum_{j = 1}^{i-1} (dp[j] × (A[j:i] の最小値))

この段階でとりあえずサンプル 2 が合うことを確かめておいた。

DP 高速化

たとえば  A = (1, 2, 5, 4, 3) のとき

  • dp[4] = -(dp[0] * 1 + dp[1] * 2 + dp[2] * 4 + dp[3] * 4)

という感じになっている。ここから dp[5] を考えると、 A_{4} = 3 を考慮に加えることになって、こんな感じになる。

  • dp[5] = -(dp[0] * 1 + dp[1] * 2 + dp[2] * 3 + dp[3] * 3 + dp[4] * 3)

この式は分解すると、

dp[5] = -(dp[0] * 1 + dp[1] * 2) - (dp[2] + dp[3] + dp[4]) * 3
= dp[2] - (dp[2] + dp[3] + dp[4]) * 3

という風に理解できる。つまり、(1, 2, 4, 4) に対して 3 がどこに挿入されるかを考えて、挿入される箇所の前の部分では過去の dp 値、挿入される箇所の後の部分では「累積和 ×  A の値」という感じになっているのだ。

よって、挿入される箇所が特定できれば、DP を高速化できる。僕はこの挿入箇所の特定を、

の遅延評価セグメント木を用意して二分探索する、みたいなことをやって TLE してしまった。

しかしよく考えたら、stack を使えば線形時間でできるのであった!!!

コード

計算量は  O(N) になる。

#include <bits/stdc++.h>
using namespace std;
using pll = pair<long long, long long>;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator < (const Fp& r) const noexcept {
        return this->val < r.val;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int main() {
    int N;
    cin >> N;
    vector<long long> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];

    stack<pll> st;
    st.push({0, 0});
    vector<mint> dp(N+1, 0), sdp(N+2, 0);
    dp[0] = 1, sdp[1] = 1;
    for (int i = 1; i <= N; ++i) {
        long long up = A[i-1];
        while (!st.empty() && st.top().first >= up) st.pop();
        long long num = st.top().second;
        st.push({up, i});
            
        if (num > 0) dp[i] += dp[num];
        dp[i] -= (sdp[i] - sdp[num]) * up;
        sdp[i+1] = sdp[i] + dp[i];
    }
    mint res = dp[N];
    if (N % 2 == 1) res = -res;
    cout << res << endl;
}

その他の解法

座標圧縮だったり、遅延評価セグメント木を使ったりなどでもできる模様。

AtCoder ABC 196 C - Doubled (灰色, 300 点)

いわゆる、 O(\sqrt{N}) まで調べれば十分というタイプの問題だね。最近そのタイプの問題が流行っている気がする!

問題概要

十進法表記で偶数桁で、かつ、その前半と後半とが文字列として等しいようなものを「良い整数」と呼ぶことにします。

 1 以上  N 以下の整数のうち、「良い整数」が何個あるかを求めよ。

制約

  •  1 \le N \lt 10^{12}

考えたこと

一見すると、 1 から  N までループを回して、「良い整数」が何個あるかを調べないといけないように感じてしまう。その場合の計算量は  O(N \log N) となる (整数  n が良い整数かどうかの判定は  O(\log n) かかる)。

このままでは間に合わない。

しかしよく考えてみると、たとえば 1234512345 は「良い整数」だが、これは「12345」という整数と同一視できる。つまり、良い整数はその情報を失わずに、より小さい整数に情報を圧縮できるのだ。

よって、このように圧縮した整数の方を全探索すればよいことに気づく。つまり、

  • 1 番目:11
  • 2 番目:22
  • 3 番目:33
  • ...
  • 9 番目:99
  • 10 番目:1010
  • 11 番目:1111
  • 12 番目:1212
  • ...
  • 99 番目:9999
  • 100 番目:100100
  • ...

という風に良い整数は列挙できる。これが  N を超えるまでやっていく感じ。最悪でも 100000 までやれば十分 (100000100000 は  10^{12} より大きい) なので、全探索で間に合う。

なお、 n 番目の「良い整数」を求める関数は、たとえば次のように実装できる。

long long reconstruct(long long n) {
    long long val = 1, nn = n;
    while (nn) {
        val *= 10;
        nn /= 10;
    }
    return n * val + n;
}

この関数は、次のようなことをしている

  • n = 12345 であるとき、
  • その桁数を  d として val =  10^{d} とする (val = 100000 となる)
  • 求める整数は n * val + n で求められる (1234512345 になる)

コード

#include <bits/stdc++.h>
using namespace std;

long long reconstruct(long long n) {
    long long val = 1, nn = n;
    while (nn) {
        val *= 10;
        nn /= 10;
    }
    return n * val + n;
}

int main() {
    long long N;
    cin >> N;
    long long res = 0;
    for (long long n = 1; n <= 1000000; ++n) {
        if (reconstruct(n) <= N) ++res;
        else break;
    }
    cout << res << endl;
}