けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

HHKB プログラミングコンテスト 2020 D - Squares (青色, 400 点)

これ、「重なるものを数える」という風に考えれば、縦方向と横方向を独立に考えれば良いことに気付けるかが結構ポイントっぽい

問題へのリンク

問題概要

整数  N, A, B が与えられます。

辺の長さが  N の白い正方形を座標平面の  (0,0),(N,0),(0,N),(N,N) に 4 頂点が重なるように置きます。

次に、この白い正方形の内部または周上に収まるように、辺の長さが  A の青い正方形と辺の長さが  B の赤い正方形を 1 つずつ置きます。

ただし、正方形のどの辺も x 軸または y 軸と平行に置かれている必要があります。

また、青い正方形と赤い正方形の各頂点は格子点上に置かれている必要があります。

赤い正方形の内部と青い正方形の内部が重ならないように置く方法の数を 1,000,000,007 で割ったあまりを求めてください。

1 つの入力につき、 T 個のテストケースに答えてください。

制約

  •  1 \le T \le 10^{5}

考えたこと

いわゆる  O(1) な数学ゲー。比較的人気がないタイプの問題かもしれないけど、高難易度になるとこういうのが部分的に要求されることがよくある。なので、400 点問題の段階からこういうのを練習するのはとてもいい感じ。

で、「長方形が重ならない」という条件は扱いづらいので、「長方形が重なる」ものを数えることにする。これは

  • 縦方向に見たときに 2 つの正方形の一辺が重なる
  • 横方向に見たときに 2 つの正方形の一辺が重なる

という条件をともに満たすことと同値。よって縦横独立に考えることができる。よって、以下の問題に帰着された。


長さ  N の区間に、長さ  A, B の棒を 2 本置く (両端点が格子点)。交差するものは何通りあるか


この答えを  X としたとき、答えは

 (N - A + 1)^{2}(N - B + 1)^{2} - X^{2}

と求められる。

帰着された問題

一次元の問題に帰着されたが、ここからは色んな方法がありそう。一つの方法として「もう一度補集合を考える」というのがある。そうすると、「2 つの棒が重ならないように配置する方法」を数えることになる。このとき、全体区間は

  • (空白)(長さ A の棒)(空白)(長さ B の棒)(空白)
  • (空白)(長さ A の棒)(空白)(長さ B の棒)(空白)

という状態になる (空白部分の長さは 0 でもよい)。一方、空白部分の長さの合計は  N - A - B なので、次の数え上げ問題となる。

  •  X_{1} + X_{2} + X_{3} = N - A - B
  •  X_{i} \ge 0

これを満たす整数  (X_{1}, X_{2}, X_{3}) の組を数えればよい (それを 2 倍する)。これは「重複組合せ」と呼ばれているもので、

 C(N - A - B + 2, 2) = \frac{1}{2}(N - A - B + 2)(N - A - B + 1) 通り

となる。以上から、

 X = (N - A + 1)(N - B + 1) -  (N - A - B + 2)(N - A - B + 1)

となることがわかった。

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

int main() {
    int T;
    cin >> T;
    while (T--) {
        long long N, A, B;
        cin >> N >> A >> B;

        if (N < A + B) {
            cout << 0 << endl;
            continue;
        }

        mint all = mint(N-A+1) * (N-A+1) * (N-B+1) * (N-B+1);
        mint X = mint(N-A+1) * (N-B+1) - mint(N-A-B+2) * (N-A-B+1);
        cout << all - X*X << endl;
    }
}