けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 280 G - Do Use Hexagon Grid 2 (赤色, 600 点)

ハニカムにそんな性質があるなんて!!! 45 度回転する技術のアナロジーが炸裂する。

あと、x 座標が最小となる点で場合分けする際に、同一の x 座標を持つものに対してダブルカウントを除去する工夫が大変だった。

問題概要

2 次元平面上に  N 個の格子点が与えられる。各座標は  (x_{i}, y_{i}) で与えられる。

この平面上で 2 点間の距離を「座標を下図のようにハニカムで表したときの hop 数」で定義することにする。

この  N 点の中から 1 つ以上選ぶ方法のうち、選んだ点のうちどの 2 点間の距離も  D 以下になるようなものは何通りあるか、998244353 で割ったあまりを求めよ。

制約

  •  1 \le N \le 300

考えたこと

まず最初に、ハニカム距離ではなくマンハッタン距離だったらどう解けばよいかを考えた。

この場合定石として「45 度回す」テクニックがある。そうすると、ある点から距離  D 以内にある点の集合の形状が、各辺が x 軸 ・y 軸に平行な正方形になるのだ。

そうなれば解きやすい。

  • 選ぶ点の x 座標の最小値 xmin
  • 選ぶ点の y 座標の最小値 ymin

をそれぞれ固定して、xmin <= x[i] <= xmin + D、ymin <= y[i] <= ymin + D を満たす i の個数を数えて、そこから集合 {xmin, ymin} のサイズを引いた値を num として、2num を足していけばよい。

なお、ymin を固定したときの num の値を求める部分は、しゃくとり法を活用することで、ならし  O(1) で求められる。よって全体として  O(N^{2}) の計算量で解ける。

この問題の場合

ここから先は分からなかったので解説を読んだ。ハニカム距離の場合にも「45 度回す」に相当するテクニックが存在することを知って感動した!!

2 点 (0, 0), (x, y) のハニカム距離 d は、3 次元空間上の 2 点 (0, 0, 0), (x, y, x - y) のチェビシェフ距離に一致する。つまり、

d = max(|x|, |y|, |x - y|)

となる。よって、上述の解法を 3 次元に拡張することで、 O(N^{3}) で解ける。

ダブルカウントを防ぐ

厄介に感じたのは、x 座標の最小値、y 座標の最小値、z 座標の最小値でそれぞれ場合分けするときに、タイが発生しうること。タイが発生すると、工夫なしでは、ダブルカウントが発生してしまう。

そこでより正確には、

  • (x 座標の値, 点の index) の最小値
  • (y 座標の値, 点の index) の最小値
  • (z 座標の値, 点の index) の最小値

で場合分けすることにした。これでダブルカウントしなくなる。

コード

#include <bits/stdc++.h>
using namespace std;
using pll = pair<long long, long long>;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int main() {
    // 入力
    long long N, D;
    cin >> N >> D;
    vector<long long> X(N), Y(N), Z(N);
    for (int i = 0; i < N; ++i) {
        cin >> X[i] >> Y[i];
        Z[i] = X[i] - Y[i];
    }

    // 2^k mod MOD
    vector<mint> tpow(N + 1, 1);
    for (int i = 0; i < N; ++i) tpow[i+1] = tpow[i] * 2;

    // X 座標 (最小)、Y 座標 (最小) を固定して Y - X で尺取り法
    mint res = 0;
    for (int x = 0; x < N; ++x) {
        for (int y = 0; y < N; ++y) {
            // X[x], Y[y] はそれぞれ最小でなければならない (同じ値の場合も順序付け)
            if (pll(X[y], y) < pll(X[x], x) || X[y] > X[x] + D) continue;
            if (pll(Y[x], x) < pll(Y[y], y) || Y[x] > Y[y] + D) continue;

            // X, Y 的に条件を満たす id のみを抽出して、Y - X の小さい順に
            vector<int> ids;
            for (int k = 0; k < N; ++k) {
                if (pll(X[k], k) < pll(X[x], x) || X[k] > X[x] + D) continue;
                if (pll(Y[k], k) < pll(Y[y], y) || Y[k] > Y[y] + D) continue;
                ids.push_back(k);
            }
            sort(ids.begin(), ids.end(), [&](int a, int b) { 
                return pll(Z[a], a) < pll(Z[b], b);
            });

            // 尺取り法
            int r = 0;
            for (int l = 0; l < ids.size(); ++l) {
                // x, y が条件を満たさない場合はダメ
                if (pll(Z[x], x) < pll(Z[ids[l]], ids[l]) || Z[x] > Z[ids[l]] + D) continue;
                if (pll(Z[y], y) < pll(Z[ids[l]], ids[l]) || Z[y] > Z[ids[l]] + D) continue;

                // 右端を求める
                while (r < ids.size() && Z[ids[r]] <= Z[ids[l]] + D) ++r;
                int num = r - l - set<int>({x, y, ids[l]}).size();
                res += tpow[num];
            }
        }
    }
    cout << res << endl;
}