けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

JOIG 2021 D - 展覧会 2 (AOJ 0704, 難易度 6)

競プロ典型 90 問 001 - Yokan Party」とよく似た問題!

前提知識

問題概要

 N 枚の絵が一直線上に順に並んでいます。 i 枚目の絵は座標  X_{i} の位置にあり、その価値は  V_{i} です。

今これらの絵から  M 枚の絵を選びます。このとき次の制約を満たさなければなりません。

  • どの 2 枚の絵も、距離が  D 以上離れなれければならない

「選んだ  M 枚の絵の価値の最小値」として考えられる最大値を求めてください。ただし制約条件を満たすように  M 枚を選ぶことが不可能である場合は -1 と出力してください。

制約

  •  1 \le M \le N \le 10^{5}

判定問題に帰着する

今回の問題は「最小値を最大化してください」という問題になっています。そのように「最大値を最小化」や「最小値や最大化」な問題では、二分探索法が有効なことが多いですね。まずは今回の問題を判定問題化*してみましょう!


選んだすべての絵の価値が  x 以上となるように、 M 枚の絵を選べるかどうかを判定してください。ただしどの 2 枚の絵の間隔も  D 以上離す必要があります。


 x が小さいときは "Yes" になります (ならないことも) し、 x が大きいときは "No" になります。二分探索法によってその境界が特定できます

この判定の問題の答えが "Yes" となる最大の  x が答えです。

具体的には次のコードのようにできます。絵の間隔としてありうる最大値を  Y として、二分探索法の反復回数は  O(\log Y) 回となります。

long long low = -1, high = 1LL<<30;
while (high - low > 1) {
    long long x = (low + high) / 2;
    if (絵の価値をすべて x にできる) low = x;
    else high = x;
}
return x;

 

判定問題を解く

まずは絵を左から順に並び替えておきます。

このとき判定問題を解くのは、次のような Greedy 解法でできます。絵を左から順に見ていき、

  • 前回選んだ絵から距離が  D 以上離れている
  • 価値が  x 以上ある

という条件をともに満たす最左の絵を選ぶことを繰り返していけばよいでしょう。そのように絵を選んでいき、 M 枚選べたら "Yes"、 M 枚に達しないうちに絵の右端に達してしまったら "No" と判定できます。

具体的な処理については、下のコードを参考にしましょう。判定問題を解くのは  O(N) の計算量でできます。

よって全体として  O(N (\log N + \log Y)) の計算量で解けます。

コード

そもそも  M 枚を選べない場合もあります。その場合であっても下のコードのように low = -1 と初期化しておくことで統一的に解けます。

#include <bits/stdc++.h>
using namespace std;
const long long INF = 1LL<<40;

int main() {
    long long N, M, D;
    cin >> N >> M >> D;
    vector<long long> X(N), V(N);
    for (int i = 0; i < N; ++i) cin >> X[i] >> V[i];

    // 一直線上に左から順に並び替える (ここでは id を並び替える)
    vector<int> ids(N);
    iota(ids.begin(), ids.end(), 0);
    sort(ids.begin(), ids.end(), [&](int i, int j) {
        return X[i] < X[j];
    });

    // 二分探索
    long long low = -1, high = INF;
    while (high - low > 1) {
        long long x = (low + high) / 2;

        int con = 0;  // 選んだ枚数
        long long prev = -INF;  // 最後に選んだ絵の位置
        for (auto id : ids) {
            if (X[id] - prev >= D && V[id] >= x) {
                ++con; 
                prev = X[id];
            }
        }
        if (con >= M) low = x;
        else high = x;
    }
    cout << low << endl;
}

JOIG 2021 C - イルミネーション 2 (AOJ 0703, 難易度 4)

落ち着いて整理して考えましょう。問題自体は「累積和」が使える良い問題ですね!

問題概要

 N 個の電球を一列に並べていて、オンオフ状態が  A_{1}, A_{2}, \dots, A_{N} であるような状態を作りたいとします。ただし  A_{i} = 1 i 番目の電球をオンにしたいことを表し、 A_{i} = 0 i 番目の電球をオフにしたいことを表します。

初期状態では、すべての電球がオフの状態です。あなたはまず、左から連続する何個かの電球をオンにすることができます。すべてオフのままでもよいですし、すべてをオンにしてもよいです。

その後、1 つの電球のオンオフ操作 (オンの電球はオフにして、オフの電球はオンにする) を実行していきます。このオンオフ操作の回数として考えられる最小値を答えてください。

制約

  •  1 \le N \le 2 \times 10^{5}

問題を整理する

まず、「最初に左から何個分の電球をオンにするか」を決めると、その後の操作回数が決まることに注意しましょう。たとえば

  •  A = (0, 1, 1, 0, 0, 1) に対して、
  • 左から 4 個分の電球をあらかじめオンにする

としたとしましょう。このとき操作回数は  A = (0, 1, 1, 0, 0, 1) (1, 1, 1, 1, 0, 0) とを比較して、相異なる箇所の個数に一致します。この場合は 3 回です。

0 1 1 0 0 1
1 1 1 1 0 0

よってまず次のような解法が考えられるでしょう。

  • 「最初にオンにする電球の個数」を全探索して、
  • それぞれについて「相異なる箇所の個数」を調べる

しかしこれは  O(N^{2}) の計算量となり、満点は得られません。

前処理で高速化

このようなとき、前処理で高速化するのが有効です。今回は次のような配列を用意しましょう。

  • left[i] A_{1}, A_{2}, \dots, A_{N} のうち、左から  i 個分の中の、オフである電球の個数
  • right[j] A_{1}, A_{2}, \dots, A_{N} のうち、右から  j 個分の中の、オンである電球の個数

これらの配列自体は  O(N) の計算量で計算できます。そしてこれらの配列が求められていると、なんと先ほどの探索が簡単に実行できます。

「最初に左から連続でオンにする個数」を  i 個としたとき、その後の必要な操作回数は

left[i] + right[N-i]

と簡単に表せるのです!! たとえば先ほどの例 ( A = (0, 1, 1, 0, 0, 1) i = 4) の場合、下図のようになります。

よって次の解法で答えが求められます。今度は「最初に左から連続でオンにする個数」を決めたときの操作回数が  O(1) で求められるので、計算量は  O(N) となります。


  1. 配列 leftright を求める
  2. i = 0, 1, ..., N に対する left[i] + right[N-i] の最小値を求める

コード

#include <bits/stdc++.h>
using namespace std;

int main() {
    int N;
    cin >> N;
    vector<int> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];

    // 前処理
    vector<int> left(N+1, 0), right(N+1, 0);
    for (int i = 0; i < N; ++i) {
        left[i+1] = left[i] + (A[i] == 0);
        right[i+1] = right[i] + (A[N-i-1] == 1);
    }

    // 最適化
    int res = N;  // 理論上の上限値
    for (int i = 0; i <= N; ++i) {
        res = min(res, left[i] + right[N-i]);
    }
    cout << res << endl;
}

AtCoder ABC 245 E - Wrapping Chocolate (水色, 500 点)

これ!!!
ABC 091 C - 2D Plane 2N Point とほとんど同じ!!
ただ制約が大きいので、貪欲法を高速化する必要がありますね。

問題概要

 N 個のチョコレートと、 M 個の箱があります。

  •  i 番目のチョコレートはサイズ  A_{i} \times B_{i} であり、
  •  j 番目の箱はサイズ  C_{j} \times D_{j} です

チョコレート  i は、 A_{i} \le C_{j} かつ  B_{i} \le D_{j} を満たすとき、箱  j に入れることができます。

すべてのチョコレートをいずれかの箱に入れた状態にすることができるかどうか判定してください。

ただし、1 つの箱には 1 つのチョコレートしか入れることができません。

制約

  •  1 \le N \le M \le 2 \times 10^{5}

貪欲法

貪欲法自体は、以前記事に書いたとおりです!!
最大マッチングを求めて、その値が  N に一致するかどうかを判定すればよいでしょう。

drken1215.hatenablog.com

箱を縦  C_{j} が小さい順に見ていきます。そして、チョコレートの縦  A_{i} C_{j} 以下であるようなチョコレートのうち、

「横の長さ  B_{i} D_{j} を超えない範囲で最大となるようなチョコレート  i

を選んで、そのチョコレート  i を箱  j に入れていきます。そのチョコレート  i は今後使えなくなります。

このアルゴリズムは、各箱に対して、該当するチョコレートを選ぼうとすると  O(NM) の計算量となります。このままでは間に合いません。

データ構造を使う

ここでは std::multiset を利用して高速化しましょう。

 j について考えているとき、 A_{i} \le C_{j} を満たすチョコレートについての  B_{i} をかきあつめた集合を multiset S で管理していたとします。

このとき、この集合の要素のうち、 D_{j} を超えない範囲で最大のものは S.upper_bound(D[j]) で取得できるイテレータをデクリメントすることで取得できます。

これによって全体の計算量は  O(N + M \log N) となります。

コード

#include <bits/stdc++.h>
using namespace std;
using pll = pair<long long, long long>;

int main() {
    int N, M;
    cin >> N >> M;
    vector<pll> cho(N), box(M);
    for (int i = 0; i < N; ++i) cin >> cho[i].first;
    for (int i = 0; i < N; ++i) cin >> cho[i].second;
    for (int i = 0; i < M; ++i) cin >> box[i].first;
    for (int i = 0; i < M; ++i) cin >> box[i].second;

    // A, C が小さい順にソート
    sort(cho.begin(), cho.end());
    sort(box.begin(), box.end());

    // C が小さい順に見ていく
    multiset<int> S;
    int res = 0;
    int i = 0;
    for (int j = 0; j < M; ++j) {
        // A[i] <= C[j] となる範囲の i をすべて動かす
        while (i < N && cho[i].first <= box[j].first) {
            S.insert(cho[i].second);
            ++i;
        }

        // その範囲で最大値を取り出す
        if (S.empty()) continue;
        auto it = S.upper_bound(box[j].second);
        if (it == S.begin()) continue;
        --it;
        S.erase(it);
        ++res;
    }
    if (res == N) cout << "Yes" << endl;
    else cout << "No" << endl;
}

AtCoder ABC 245 D - Polynomial division (緑色, 400 点)

問題を見て「めっちゃ数学やん!!なにこれ!??」となった人は多いと思う!!!

でも落ち着いて整理して取り組めば解けるので、落ち着くことが大事そう。

もしくは、ライブラリで殴る!!!!!!

問題概要

 3 つの多項式

  •  N 次の多項式  A(x) = A_{N}x^{N} + A_{N-1}x^{N-1} + \dots + A_{0}
  •  M 次の多項式  B(x) = B_{M}x^{M} + B_{M-1}x^{M-1} + \dots + B_{0}
  •  N+M 次の多項式  C(x) = C_{N+M}x^{N+M} + C_{N+M-1}x^{N+M-1} + \dots + C_{0}

があって、 A(x)B(x) = C(x) という関係を満たしています。

入力として、多項式  A(x) C(x) が与えられますので、 B(x) を復元してください。

制約

  •  1 \le N, M \le 100

具体例で考える

この手の問題でいきなり抽象的な数式変形で解ける人はそれでよくて、そうでない場合は具体例を手で動かして様子を見るのが大事そうです!!

たとえばサンプル 1 を例にとります。

 A(x)=x+2,  C(x)=2x^{3}+8x^{2}+14x+12 に対して、まず最高次の項 ( x 2x^{3}) を合わせるためには、

 B(x) = 2x^{2} + \dots

とする必要があることがわかります。このとき、

 A(x)(2x^{2}) = 2x^{3} + 4x^{2}

となります。よって  A(x)B(x) C(x) に一致するためには、

 C(x) - A(x)(2x) = 4x^{2} + 14x + 12

の分だけ補充する必要があります。ここで再び  A(x) = x + 2 4x^{2} + 14x + 12 の最高次の項を比較すると、さっきの  B(x) = 2x^{2} + \dots の「 \dots」の部分の先頭が  4x であることが必要だとわかります。つまり

 B(x) = 2x^{2} + 4x + \dots

とする必要があることがわかります。そしてこのとき

 (4x^{2} + 14x + 12) - A(x)(4x) = 6x + 12

の分だけ補充する必要があることがわかるので、また最高次の項を比較して、 B(x) = 2x^{2} + 4x + \dots の「 \dots」の部分の先頭が  6 だとわかります。つまり、

 B(x) = 2x^{2} + 4x + 6 + \dots

とする必要があることがわかります。このとき、

 (6x + 12) - A(x)\times 6 = 0

となり、ちょうどぴったり  A(x) \times B(x) = C(x) となります。

以上より、 B(x) = 2x^{2} + 4x + 6 と求められました。

なお高校数学の数学 IA で「多項式の除算」を学んだことのある方は、以上の手続きを「筆算」でやったことを思い出す方もいるでしょう!!

 

数式で一般化

以上の手計算を、一般の入力ケースに通用するアルゴリズムに落とし込みましょう。

まず初期状態では  B(x) = 0 としておきます。最初に  A(x), C(x) の最高次の項である  A_{N}x^{N} C_{N+M}x^{N+M} とを比較することで、 B(x), C(x) を次のように更新します。

  •  B(x) +=  (C_{N+M} / A_{N}) x^{M}
  •  C(x) -=  A(x) (C_{N+M} / A_{N}) x^{M}

より具体的には次のような感じです。

B[M] = C[N+M] / A[N];
C[N+M] -= A[N] * B[M];
C[N+M-1] -= A[N-1] * B[M];
C[N+M-2] -= A[N-2] * B[M];
...
C[M] -= A[0] * B[M];

そうすると  C(x) N+M-1 次の多項式となります。そしてふたたび  A(x), C(x) の最高次の項を比較して、 B(x), C(x) を次のように更新します。

  •  B(x) +=  (C_{N+M-1} / A_{N}) x^{M-1}
  •  C(x) -=  A(x) (C_{N+M-1} / A_{N}) x^{M-1}

以下同様に、 A(x), C(x) の最高次の項を比較した結果を  B(x), C(x) に反映する作業を繰り返していけばよいでしょう。

より一般に、 B(x) i 次の項を求めるときは、次のようにします。

B[i] = C[N+i] / A[N];
C[N+i] -= A[N] * B[i];
C[N-1+i] -= A[N-1] * B[i];
C[N-2+i] -= A[N-2] * B[i];
...
C[i] -= A[0] * B[i];

数式で書くと複雑ですが、具体例をなんとかコードに落とし込むようにすることで、少しずつ納得できるかなと思います。

 

コード

計算量は  O(NM) となります。

#include <bits/stdc++.h>
using namespace std;

int main() {
    int N, M;
    cin >> N >> M;
    vector<long long> A(N + 1), C(N + M + 1);
    for (int i = 0; i <= N; ++i) cin >> A[i];
    for (int i = 0; i <= N + M; ++i) cin >> C[i];

    // A の最高次項と C の最高次項を常に比較して B に加算する
    vector<long long> B(M + 1, 0);
    for (int i = M; i >= 0; --i) {
        // そのときの C の最高次係数を A の最高次係数で割る
        B[i] = C[i + N] / A[N];

        // C から、新たな B の増加分を A との積を引く
        for (int j = N; j >= 0; --j) {
            C[j + i] -= A[j] * B[i];
        }
    }

    // 出力
    for (int i = 0; i <= M; ++i) cout << B[i] << " ";
    cout << endl;
}

 

別解:ライブラリで殴る

多項式の除算は、形式的冪級数ライブラリを持っていれば、それを貼るだけで解けます。しかも  S = \max(M, N) として、 O(S \log S) の計算量で解けます!!!

github.com

このライブラリは、たとえば mod. 998244353 などで割り算をするものなので、 B(x) = 2x - 1 などは

 B(x) = 2x + 998244352

などとなります。よって、 B(x) が負の値だと思われる場合 (具体的には 100000 より大きい場合) には、998244353 を引いて補正します。

以下のコードで、main 関数の短さに注目!!!

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

namespace NTT {
    long long modpow(long long a, long long n, int mod) {
        long long res = 1;
        while (n > 0) {
            if (n & 1) res = res * a % mod;
            a = a * a % mod;
            n >>= 1;
        }
        return res;
    }

    long long modinv(long long a, int mod) {
        long long b = mod, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        u %= mod;
        if (u < 0) u += mod;
        return u;
    }

    int calc_primitive_root(int mod) {
        if (mod == 2) return 1;
        if (mod == 167772161) return 3;
        if (mod == 469762049) return 3;
        if (mod == 754974721) return 11;
        if (mod == 998244353) return 3;
        int divs[20] = {};
        divs[0] = 2;
        int cnt = 1;
        long long x = (mod - 1) / 2;
        while (x % 2 == 0) x /= 2;
        for (long long i = 3; i * i <= x; i += 2) {
            if (x % i == 0) {
                divs[cnt++] = i;
                while (x % i == 0) x /= i;
            }
        }
        if (x > 1) divs[cnt++] = x;
        for (int g = 2;; g++) {
            bool ok = true;
            for (int i = 0; i < cnt; i++) {
                if (modpow(g, (mod - 1) / divs[i], mod) == 1) {
                    ok = false;
                    break;
                }
            }
            if (ok) return g;
        }
    }

    int get_fft_size(int N, int M) {
        int size_a = 1, size_b = 1;
        while (size_a < N) size_a <<= 1;
        while (size_b < M) size_b <<= 1;
        return max(size_a, size_b) << 1;
    }

    // number-theoretic transform
    template<class mint> void trans(vector<mint>& v, bool inv = false) {
        if (v.empty()) return;
        int N = (int)v.size();
        int MOD = v[0].getmod();
        int PR = calc_primitive_root(MOD);
        static bool first = true;
        static vector<long long> vbw(30), vibw(30);
        if (first) {
            first = false;
            for (int k = 0; k < 30; ++k) {
                vbw[k] = modpow(PR, (MOD - 1) >> (k + 1), MOD);
                vibw[k] = modinv(vbw[k], MOD);
            }
        }
        for (int i = 0, j = 1; j < N - 1; j++) {
            for (int k = N >> 1; k > (i ^= k); k >>= 1);
            if (i > j) swap(v[i], v[j]);
        }
        for (int k = 0, t = 2; t <= N; ++k, t <<= 1) {
            long long bw = vbw[k];
            if (inv) bw = vibw[k];
            for (int i = 0; i < N; i += t) {
                mint w = 1;
                for (int j = 0; j < t/2; ++j) {
                    int j1 = i + j, j2 = i + j + t/2;
                    mint c1 = v[j1], c2 = v[j2] * w;
                    v[j1] = c1 + c2;
                    v[j2] = c1 - c2;
                    w *= bw;
                }
            }
        }
        if (inv) {
            long long invN = modinv(N, MOD);
            for (int i = 0; i < N; ++i) v[i] = v[i] * invN;
        }
    }

    // for garner
    static constexpr int MOD0 = 754974721;
    static constexpr int MOD1 = 167772161;
    static constexpr int MOD2 = 469762049;
    using mint0 = Fp<MOD0>;
    using mint1 = Fp<MOD1>;
    using mint2 = Fp<MOD2>;
    static const mint1 imod0 = 95869806; // modinv(MOD0, MOD1);
    static const mint2 imod1 = 104391568; // modinv(MOD1, MOD2);
    static const mint2 imod01 = 187290749; // imod1 / MOD0;

    // small case (T = mint, long long)
    template<class T> vector<T> naive_mul 
    (const vector<T>& A, const vector<T>& B) {
        if (A.empty() || B.empty()) return {};
        int N = (int)A.size(), M = (int)B.size();
        vector<T> res(N + M - 1);
        for (int i = 0; i < N; ++i)
            for (int j = 0; j < M; ++j)
                res[i + j] += A[i] * B[j];
        return res;
    }

    // mint
    template<class mint> vector<mint> mul
    (const vector<mint>& A, const vector<mint>& B) {
        if (A.empty() || B.empty()) return {};
        int N = (int)A.size(), M = (int)B.size();
        if (min(N, M) < 30) return naive_mul(A, B);
        int MOD = A[0].getmod();
        int size_fft = get_fft_size(N, M);
        if (MOD == 998244353) {
            vector<mint> a(size_fft), b(size_fft), c(size_fft);
            for (int i = 0; i < N; ++i) a[i] = A[i];
            for (int i = 0; i < M; ++i) b[i] = B[i];
            trans(a), trans(b);
            vector<mint> res(size_fft);
            for (int i = 0; i < size_fft; ++i) res[i] = a[i] * b[i];
            trans(res, true);
            res.resize(N + M - 1);
            return res;
        }
        vector<mint0> a0(size_fft, 0), b0(size_fft, 0), c0(size_fft, 0);
        vector<mint1> a1(size_fft, 0), b1(size_fft, 0), c1(size_fft, 0);
        vector<mint2> a2(size_fft, 0), b2(size_fft, 0), c2(size_fft, 0);
        for (int i = 0; i < N; ++i)
            a0[i] = A[i].val, a1[i] = A[i].val, a2[i] = A[i].val;
        for (int i = 0; i < M; ++i)
            b0[i] = B[i].val, b1[i] = B[i].val, b2[i] = B[i].val;
        trans(a0), trans(a1), trans(a2), trans(b0), trans(b1), trans(b2);
        for (int i = 0; i < size_fft; ++i) {
            c0[i] = a0[i] * b0[i];
            c1[i] = a1[i] * b1[i];
            c2[i] = a2[i] * b2[i];
        }
        trans(c0, true), trans(c1, true), trans(c2, true);
        static const mint mod0 = MOD0, mod01 = mod0 * MOD1;
        vector<mint> res(N + M - 1);
        for (int i = 0; i < N + M - 1; ++i) {
            int y0 = c0[i].val;
            int y1 = (imod0 * (c1[i] - y0)).val;
            int y2 = (imod01 * (c2[i] - y0) - imod1 * y1).val;
            res[i] = mod01 * y2 + mod0 * y1 + y0;
        }
        return res;
    }

    // long long
    vector<long long> mul_ll
    (const vector<long long>& A, const vector<long long>& B) {
        if (A.empty() || B.empty()) return {};
        int N = (int)A.size(), M = (int)B.size();
        if (min(N, M) < 30) return naive_mul(A, B);
        int size_fft = get_fft_size(N, M);
        vector<mint0> a0(size_fft, 0), b0(size_fft, 0), c0(size_fft, 0);
        vector<mint1> a1(size_fft, 0), b1(size_fft, 0), c1(size_fft, 0);
        vector<mint2> a2(size_fft, 0), b2(size_fft, 0), c2(size_fft, 0);
        for (int i = 0; i < N; ++i)
            a0[i] = A[i], a1[i] = A[i], a2[i] = A[i];
        for (int i = 0; i < M; ++i)
            b0[i] = B[i], b1[i] = B[i], b2[i] = B[i];
        trans(a0), trans(a1), trans(a2), trans(b0), trans(b1), trans(b2);
        for (int i = 0; i < size_fft; ++i) {
            c0[i] = a0[i] * b0[i];
            c1[i] = a1[i] * b1[i];
            c2[i] = a2[i] * b2[i];
        }
        trans(c0, true), trans(c1, true), trans(c2, true);
        static const long long mod0 = MOD0, mod01 = mod0 * MOD1;
        vector<long long> res(N + M - 1);
        for (int i = 0; i < N + M - 1; ++i) {
            int y0 = c0[i].val;
            int y1 = (imod0 * (c1[i] - y0)).val;
            int y2 = (imod01 * (c2[i] - y0) - imod1 * y1).val;
            res[i] = mod01 * y2 + mod0 * y1 + y0;
        }
        return res;
    }
};

// Binomial Coefficient
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

// Formal Power Series
template <typename mint> struct FPS : vector<mint> {
    using vector<mint>::vector;
 
    // constructor
    FPS(const vector<mint>& r) : vector<mint>(r) {}
 
    // core operator
    inline FPS pre(int siz) const {
        return FPS(begin(*this), begin(*this) + min((int)this->size(), siz));
    }
    inline FPS rev() const {
        FPS res = *this;
        reverse(begin(res), end(res));
        return res;
    }
    inline FPS& normalize() {
        while (!this->empty() && this->back() == 0) this->pop_back();
        return *this;
    }
 
    // basic operator
    inline FPS operator - () const noexcept {
        FPS res = (*this);
        for (int i = 0; i < (int)res.size(); ++i) res[i] = -res[i];
        return res;
    }
    inline FPS operator + (const mint& v) const { return FPS(*this) += v; }
    inline FPS operator + (const FPS& r) const { return FPS(*this) += r; }
    inline FPS operator - (const mint& v) const { return FPS(*this) -= v; }
    inline FPS operator - (const FPS& r) const { return FPS(*this) -= r; }
    inline FPS operator * (const mint& v) const { return FPS(*this) *= v; }
    inline FPS operator * (const FPS& r) const { return FPS(*this) *= r; }
    inline FPS operator / (const mint& v) const { return FPS(*this) /= v; }
    inline FPS operator << (int x) const { return FPS(*this) <<= x; }
    inline FPS operator >> (int x) const { return FPS(*this) >>= x; }
    inline FPS& operator += (const mint& v) {
        if (this->empty()) this->resize(1);
        (*this)[0] += v;
        return *this;
    }
    inline FPS& operator += (const FPS& r) {
        if (r.size() > this->size()) this->resize(r.size());
        for (int i = 0; i < (int)r.size(); ++i) (*this)[i] += r[i];
        return this->normalize();
    }
    inline FPS& operator -= (const mint& v) {
        if (this->empty()) this->resize(1);
        (*this)[0] -= v;
        return *this;
    }
    inline FPS& operator -= (const FPS& r) {
        if (r.size() > this->size()) this->resize(r.size());
        for (int i = 0; i < (int)r.size(); ++i) (*this)[i] -= r[i];
        return this->normalize();
    }
    inline FPS& operator *= (const mint& v) {
        for (int i = 0; i < (int)this->size(); ++i) (*this)[i] *= v;
        return *this;
    }
    inline FPS& operator *= (const FPS& r) {
        return *this = NTT::mul((*this), r);
    }
    inline FPS& operator /= (const mint& v) {
        assert(v != 0);
        mint iv = modinv(v);
        for (int i = 0; i < (int)this->size(); ++i) (*this)[i] *= iv;
        return *this;
    }
    inline FPS& operator <<= (int x) {
        FPS res(x, 0);
        res.insert(res.end(), begin(*this), end(*this));
        return *this = res;
    }
    inline FPS& operator >>= (int x) {
        FPS res;
        res.insert(res.end(), begin(*this) + x, end(*this));
        return *this = res;
    }
    inline mint eval(const mint& v){
        mint res = 0;
        for (int i = (int)this->size()-1; i >= 0; --i) {
            res *= v;
            res += (*this)[i];
        }
        return res;
    }
    inline friend FPS gcd(const FPS& f, const FPS& g) {
        if (g.empty()) return f;
        return gcd(g, f % g);
    }

    // advanced operation
    // df/dx
    inline friend FPS diff(const FPS& f) {
        int n = (int)f.size();
        FPS res(n-1);
        for (int i = 1; i < n; ++i) res[i-1] = f[i] * i;
        return res;
    }

    // \int f dx
    inline friend FPS integral(const FPS& f) {
        int n = (int)f.size();
        FPS res(n+1, 0);
        for (int i = 0; i < n; ++i) res[i+1] = f[i] / (i+1);
        return res;
    }

    // inv(f), f[0] must not be 0
    inline friend FPS inv(const FPS& f, int deg) {
        assert(f[0] != 0);
        if (deg < 0) deg = (int)f.size();
        FPS res({mint(1) / f[0]});
        for (int i = 1; i < deg; i <<= 1) {
            res = (res + res - res * res * f.pre(i << 1)).pre(i << 1);
        }
        res.resize(deg);
        return res;
    }
    inline friend FPS inv(const FPS& f) {
        return inv(f, f.size());
    }

    // division, r must be normalized (r.back() must not be 0)
    inline FPS& operator /= (const FPS& r) {
        assert(!r.empty());
        assert(r.back() != 0);
        this->normalize();
        if (this->size() < r.size()) {
            this->clear();
            return *this;
        }
        int need = (int)this->size() - (int)r.size() + 1;
        *this = ((*this).rev().pre(need) * inv(r.rev(), need)).pre(need).rev();
        return *this;
    }
    inline FPS& operator %= (const FPS &r) {
        assert(!r.empty());
        assert(r.back() != 0);
        this->normalize();
        FPS q = (*this) / r;
        return *this -= q * r;
    }
    inline FPS operator / (const FPS& r) const { return FPS(*this) /= r; }
    inline FPS operator % (const FPS& r) const { return FPS(*this) %= r; }

    // log(f) = \int f'/f dx, f[0] must be 1
    inline friend FPS log(const FPS& f, int deg) {
        assert(f[0] == 1);
        FPS res = integral(diff(f) * inv(f, deg));
        res.resize(deg);
        return res;
    }
    inline friend FPS log(const FPS& f) {
        return log(f, f.size());
    }

    // exp(f), f[0] must be 0
    inline friend FPS exp(const FPS& f, int deg) {
        assert(f[0] == 0);
        FPS res(1, 1);
        for (int i = 1; i < deg; i <<= 1) {
            res = res * (f.pre(i<<1) - log(res, i<<1) + 1).pre(i<<1);
        }
        res.resize(deg);
        return res;
    }
    inline friend FPS exp(const FPS& f) {
        return exp(f, f.size());
    }

    // pow(f) = exp(e * log f)
    inline friend FPS pow(const FPS& f, long long e, int deg) {
        long long i = 0;
        while (i < (int)f.size() && f[i] == 0) ++i;
        if (i == (int)f.size()) return FPS(deg, 0);
        if (i * e >= deg) return FPS(deg, 0);
        mint k = f[i];
        FPS res = exp(log((f >> i) / k, deg) * e, deg) * modpow(k, e) << (e * i);
        res.resize(deg);
        return res;
    }
    inline friend FPS pow(const FPS& f, long long e) {
        return pow(f, e, f.size());
    }

    // sqrt(f), f[0] must be 1
    inline friend FPS sqrt_base(const FPS& f, int deg) {
        assert(f[0] == 1);
        mint inv2 = mint(1) / 2;
        FPS res(1, 1);
        for (int i = 1; i < deg; i <<= 1) {
            res = (res + f.pre(i << 1) * inv(res, i << 1)).pre(i << 1);
            for (mint& x : res) x *= inv2;
        }
        res.resize(deg);
        return res;
    }
    inline friend FPS sqrt_base(const FPS& f) {
        return sqrt_base(f, f.size());
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int main() {
    int N, M;
    cin >> N >> M;
    FPS<mint> A(N + 1), C(N + M + 1);
    for (int i = 0; i <= N; ++i) cin >> A[i];
    for (int i = 0; i <= N + M; ++i) cin >> C[i];

    // 求める
    FPS<mint> B = C / A;

    // 出力
    for (int i = 0; i <= M; ++i) {
        cout << (B[i].val <= 1000000 ? B[i].val : B[i].val - MOD) << " ";
    }
    cout << endl;
}

AtCoder ABC 245 C - Choose Elements (茶色, 300 点)

EDPC C - Vacation と良く似た問題だと思う!!

あと、幅が狭いグリッドでは DP が疑われることが結構多い!

問題概要

長さが  N の数列が 2 つ ( A_{0}, \dots, A_{N-1} B_{0}, \dots, B_{N-1}) 与えられます。

 i = 0, 1, \dots, N-1 に対して、 A_{i} B_{i} のいずれかを選ぶことで、新たに数列  W_{0}, \dots, W_{N-1} を作ります。

こうしてできる数列が  i = 0, 1, \dots, N-2 に対して  |W_{i} - W_{i+1}| \ge K を満たすようにできるかどうかを判定してください。

制約

  •  1 \le N \le 2 \times 10^{5}
  •  0 \le K \le 10^{9}

考えたこと

たとえば  A = (9, 8, 3, 7, 2) B = (1, 6, 2, 9, 5) K = 4 に対しては、下図のようなグラフに対して、左側から出発して右側へと到達できる方法があるかどうかを判定する問題だと言えます。

 K = 4 なので、差が  4 以下の部分にのみ矢印を引いています。

Union-Find は嘘解法

「左側と右側がつながれば Yes、途切れたら No」と考えると、一瞬 Union-Find が思い浮かぶかもしれません。

しかし Union-Find を用いると、下図のグラフに対して Yes と返してしまいます (正解は No です)。一旦戻るのは禁止ですが、 Union-Find でそれを禁止するのは困難です。

DP へ

今回のグラフは「左から右へと一方向に流れていくグラフ」です。こういうグラフの経路を扱う問題では DP が有効です。たとえば次のような DP テーブルを用意しましょう。


  • dpA[i]:上側の i 番目の頂点に到達できれば true、そうでなければ false
  • dpB[i]:下側の i 番目の頂点に到達できれば true、そうでなければ false

まず最初は左から 0 番目の頂点は上下ともに到達できる (開始時点で到達している) ので、

dpA[0] = dpB[0] = true;

とします。そして、 i 番目の頂点から  i+1 番目の頂点への遷移を考えます。

上側の  i 番目の頂点からの遷移

もし dpA[i] = true ならば、

  • 上側の  i 番目の頂点から、上側の  i+1 番目の頂点へ移動できるとき、dpA[i+1] = true と更新します
  • 上側の  i 番目の頂点から、下側の  i+1 番目の頂点へ移動できるとき、dpB[i+1] = true と更新します

dpA[i] = false のときは何もしません

下側の  i 番目の頂点からの遷移

もし dpB[i] = true ならば、

  • 下側の  i 番目の頂点から、上側の  i+1 番目の頂点へ移動できるとき、dpA[i+1] = true と更新します
  • 下側の  i 番目の頂点から、下側の  i+1 番目の頂点へ移動できるとき、dpB[i+1] = true と更新します

dpB[i] = false のときは何もしません

  そして最後に、dpA[N-1]dpB[N-1] のいずれかが true ならば "Yes"、そうでなければ "No" です。

 

コード

計算量は  O(N) となります。

#include <bits/stdc++.h>
using namespace std;

int main() {
    // 入力
    long long N, K;
    cin >> N >> K;
    vector<long long> A(N), B(N);
    for (int i = 0; i < N; ++i) cin >> A[i];
    for (int i = 0; i < N; ++i) cin >> B[i];

    // dp
    vector<bool> dpA(N, false), dpB(N, false);
    dpA[0] = dpB[0] = true;
    for (int w = 0; w < N-1; ++w) {
        // A の w 番目まで到達できるとき、さらに伸ばす
        if (dpA[w]) {
            if (abs(A[w] - A[w+1]) <= K) dpA[w+1] = true;
            if (abs(A[w] - B[w+1]) <= K) dpB[w+1] = true;
        }
        // B の w 番目まで到達できるとき、さらに伸ばす
        if (dpB[w]) {
            if (abs(B[w] - A[w+1]) <= K) dpA[w+1] = true;
            if (abs(B[w] - B[w+1]) <= K) dpB[w+1] = true;
        }
    }

    // A, B のどちらかの N-1 番目まで到達できれば Yes
    if (dpA[N-1] || dpB[N-1]) cout << "Yes" << endl;
    else cout << "No" << endl;
}