けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder AGC 025 B - RGB Coloring (青色, 700 点)

最初迷走したけど、順位表見ると上位陣が 5 分とかで解いていて、「いくらなんでもこの方針で 5 分はない、きっとなにか簡潔な視点があるはず」と思えて思い付けたのがよかった。

問題概要

N 個のマスを赤、緑、青、無の 4 色に塗り分ける。

A * (赤数) + (A + B) * (緑数) + B * (青数) = K

となるような塗り分け方が何通りあるかを 998244353 で割ったあまりで求めよ。

制約

  • 1 ≦ N ≦ 3×105
  • 1 ≦ A, B ≦ 3×105
  • 0 ≦ K ≦ 18×1010

解法

A の個数を a 個、B の個数を b 個として、

  • a <= N
  • b <= N
  • Aa + Bb = K

を満たすような (a, b) に対して、

  • N マスのうち A に対応する a 個を選ぶ (NCa 通り)
  • N マスのうち B に対応する b 個を選ぶ (NCb 通り)

としてあげて、

  • A のみに対応する部分: 赤色
  • B のみに対応する部分: 青色
  • A にも B にも対応する部分: 緑色

に塗ってあげればよい。したがって、各 (a, b) に対して NCa x NCb を合算する問題になる。

なお、nCk を求める方法は n と k の状況に応じて様々なものが考えられるが、以下のやり方が最もよく登場すると思われる:

nCk = \frac{n!}{k!(n-k)!} = (n!)(k!)^{-1}((n-k)!)^{-1}

であることを利用します。a!(a!)^{-1} のテーブルを予め作っておいて計算しています。

const int MAX = 510000;
const int MOD = 998244353;

long long fac[MAX], finv[MAX], inv[MAX];

// テーブルを作る前処理
void COMinit(){
    fac[0] = fac[1] = 1;
    finv[0] = finv[1] = 1;
    inv[1] = 1;
    for(int i = 2; i < MAX; i++){
        fac[i] = fac[i-1] * i % MOD;
        inv[i] = MOD - inv[MOD%i] * (MOD/i) % MOD;
        finv[i] = finv[i-1] * inv[i] % MOD;
    }
}

// 二項係数計算
long long COM(int n,int k){
    if(n < k) return 0;
    if (n < 0 || k < 0) return 0;
    return fac[n] * (finv[k] * finv[n-k] % MOD) % MOD;
}

問題を解くコード

#include <iostream>
using namespace std;

const int MAX = 510000;
const int MOD = 998244353;

long long fac[MAX], finv[MAX], inv[MAX];


void COMinit(){
    fac[0] = fac[1] = 1;
    finv[0] = finv[1] = 1;
    inv[1] = 1;
    for (int i = 2; i < MAX; i++){
        fac[i] = fac[i - 1] * i % MOD;
        inv[i] = MOD - inv[MOD%i] * (MOD / i) % MOD;
        finv[i] = finv[i - 1] * inv[i] % MOD;
    }
}
long long COM(int n, int k){
    if (n < k) return 0;
    if (n < 0 || k < 0) return 0;
    return fac[n] * (finv[k] * finv[n - k] % MOD) % MOD;
}

long long N, A, B, K;

int main() {
    COMinit(); // テーブルを作る

    cin >> N >> A >> B >> K;
    long long res = 0;
    for (long long a = 0; a <= N; ++a) { // A の個数を全探索
        long long rem = K - a * A;
        if (rem % B != 0) continue;
        long long b = rem / B;
        if (b > N) continue;
        long long tmp = COM(N, a) * COM(N, b) % MOD;
        res += tmp;
        res %= MOD;
    }

    cout << res << endl;
}