けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 265 F - Manhattan Cafe (黄色, 500 点)

 N 次元空間という、いかめしいものが出てくるけど、あまり関係ない。DP 高速化が本質。

問題概要

 N 次元空間上に 2 つの格子点  (p_{1}, \dots, p_{N}),  (q_{1}, \dots, q_{N}) が与えられる。

これらとのマンハッタン距離がともの  D 以下であるような格子点の個数を 998244353 で割ったあまりを求めよ。

制約

  •  1 \le N \le 100
  •  0 \le D, |p_{i}|, |q_{i}| \le 1000

考えたこと

0-indexed で考える。

 N 次元空間とか、マンハッタン距離とか聞いて、最初は幾何学的イメージを思い描いて解こうとしていた。

しかし冷静に考えると、幾何学的考察ではなくて、DP を導くところまでは比較的容易だった。今回の問題は要するに

  •  |x_{0} - p_{0}| + |x_{1} - p_{1}| + \dots + |x_{N-1} - p_{N-1}| \le D
  •  |x_{0} - q_{0}| + |x_{1} - q_{1}| + \dots + |x_{N-1} - q_{N-1}| \le D

をともに満たすような整数の組  (x_{0}, x_{1}, \dots, x_{N-1}) の個数を求めたいということだ。次の DP でできる。


dp[i][j][k] |x_{0}-p_{0}| + \dots + |x_{i-1} - p_{i-1}| = j かつ  |x_{0}-q_{0}| + \dots + |x_{i-1} - q_{i-1}| = k を満たすような整数の組  (x_{0}, x_{1}, \dots, x_{i-1}) の個数


これで計算量は  O(ND^{3}) となる。

高速化

このままだと TLE となるが、例によって、累積和を用いた DP 高速化で  O(ND^{2}) にできる。

まず簡単のため、 (p_{i}, q_{i}) に対して、鏡像変換や平行移動によって、 (0, |p_{i} - q_{i}|) としても答えが変わらないことに注意しよう。 A_{i} = |p_{i} - q_{i}| とおくことにする。

このとき、dp[i+1][a][b] の値を更新する部分を書き出すと次の様になる (添字が負になる部分は 0 と考える)。

dp[i+1][a][b] += dp[i][a-1][b-A[i]-1] + dp[i][a-2][b-A[i]-2] + ...
dp[i+1][a][b] += dp[i][a-A[i]-1][b-1] + dp[i][a-A[i]-2][b-2] + ...
dp[i+1][a][b] += dp[i][a][b-A[i]] + dp[i][a-1][b-A[i]+1] + ... + dp[i][a-A[i]][b]

これは、二次元配列を斜めに値を足していくような累積和によって効率よく計算できる。具体的には、配列 dp[i] に対して、次の累積和を用意することにした。

sum1[j][k+1] = dp[0][j] + dp[1][j-1] + ... + dp[k][j-k]
sum2[j][k+1] = dp[0][j] + dp[1][j+1] + ... + dp[k][j+k]
sum3[j][k+1] = dp[j][0] + dp[j+1][1] + ... + dp[j+k][k]

これによって、DP の更新が  O(1) の計算量でできるようになる。

全体の計算量も  O(ND^{2}) となって間に合う。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    // inner value
    long long val;
    
    // constructor
    constexpr Fp() noexcept : val(0) { }
    constexpr Fp(long long v) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr long long get() const noexcept { return val; }
    constexpr int get_mod() const noexcept { return MOD; }
    
    // arithmetic operators
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp &r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp &r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp &r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp &r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp &r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp &r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp &r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp &r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp pow(long long n) const noexcept {
        Fp res(1), mul(*this);
        while (n > 0) {
            if (n & 1) res *= mul;
            mul *= mul;
            n >>= 1;
        }
        return res;
    }
    constexpr Fp inv() const noexcept {
        Fp res(1), div(*this);
        return res / div;
    }

    // other operators
    constexpr bool operator == (const Fp &r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp &r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD> &x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD> &x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &r, long long n) noexcept {
        return r.pow(n);
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD> &r) noexcept {
        return r.inv();
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;


int main() {
    int N, D;
    cin >> N >> D;
    vector<int> p(N), q(N), A(N);
    for (int i = 0; i < N; ++i) cin >> p[i];
    for (int i = 0; i < N; ++i) cin >> q[i];
    for (int i = 0; i < N; ++i) A[i] = abs(p[i] - q[i]);
    
    // DP
    vector<vector<mint>> dp(2100, vector<mint>(2100, 0));
    vector<vector<mint>> sum1, sum2, sum3;
    dp[0][0] = 1;
    for (int i = 0; i < N; ++i) {
        // ruiseki-wa
        sum1.assign(2100, vector<mint>(2100, 0));
        sum2.assign(2100, vector<mint>(2100, 0));
        sum3.assign(2100, vector<mint>(2100, 0));
        for (int j = 0; j <= D*2; ++j) {
            for (int k = 0; k <= D; ++k) {
                sum1[j][k+1] = sum1[j][k] + (j >= k ? dp[k][j-k] : 0);
            }
        }
        for (int j = 0; j <= D; ++j) {
            for (int k = 0; k <= D; ++k) {
                sum2[j][k+1] = sum2[j][k] + dp[k][j+k];
                sum3[j][k+1] = sum3[j][k] + dp[j+k][k];
            }
        }
         
        // dp
        vector<vector<mint>> nex(2100, vector<mint>(2100, 0));
        for (int a = 0; a <= D; ++a) {
            for (int b = 0; b <= D; ++b) {
                if (a+b >= A[i]) {
                    nex[a][b] += sum1[a+b-A[i]][a+1] - sum1[a+b-A[i]][max(a-A[i],0)];
                }
                if (b >= A[i]) {
                    if (a >= b-A[i]) nex[a][b] += sum3[a-b+A[i]][b-A[i]];
                    else nex[a][b] += sum2[b-A[i]-a][a];
                }
                if (a >= A[i]) {
                    if (b >= a-A[i]) nex[a][b] += sum2[b-a+A[i]][a-A[i]];
                    else nex[a][b] += sum3[a-A[i]-b][b];
                }
            }
        }
        swap(dp, nex);
    }
    mint res = 0;
    for (int a = 0; a <= D; ++a) {
        for (int b = 0; b <= D; ++b) {
            res += dp[a][b];
        }
    }
    cout << res << endl;
}