けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 096 E - Everything on It (900 点)

部分点がなければ CE 2 完でも赤パフォ出たのに...
それはともかく、この手の包除で絶対に解けるという安定感をもって解けるようになりたい!

問題へのリンク

問題概要

ラーメンに  N 種類のトッピングを自由に組み合わせて乗せることができます。トッピングの組合せは  2^{N} 通りあります。

この  2^{N} 通りから何個か選ぶ  2^{2^{N}} 通りの方法のうち、どのトッピングについても 2 個以上含まれているようなものの場合の数を素数  M で割ったあまりを求めよ。

制約

考えたこと

一目包除原理ではある。なぜなら、条件は

  • トッピング 0 について、2 個以上含まれている
  • トッピング 1 について、2 個以上含まれている
  • ...
  • トッピング  N-1 について、2 個以上含まれている

という「かつ」の形で書かれているものの、その一つ一つが大変扱いづらい。こういうときは余事象をとると、

  • トッピング 0 について 1 個以下しかない、または
  • トッピング 1 について 1 個以下しかない、または
  • トッピング  N-1 について 1 個以下しかない

という「または」で表された形になる。「または」になるのは確かに扱いづらくなるのだが、その代わり一つ一つの条件は考えやすくなった。こういうときは包除原理がジャストフィットする。

包除原理で求めたいもの

包除原理とはざっくりと言えば

(条件から 0 個選んで、それらを満たさないようにする場合の数) (これはつまり「全体集合」)
- (条件から 1 個選んで、それらを満たさないようにする場合の数)
+ (条件から 2 個選んで、それらを満たさないようにする場合の数)
- (条件から 3 個選んで、それらを満たさないようにする場合の数)
+ (条件から 4 個選んで、それらを満たさないようにする場合の数)
- ...

を計算したいというもの。今回の問題で言えば、 N 種類あるトッピングのうちの  n 個を選んで、それらが 1 個以下しか含まれないようにする場合の数を求めれば良いということになる。その  n 個のトッピングを選ぶ方法は  C(N, n) 通りであり、これを最後に掛け算する。

また、その  n 個のトッピングのいずれかを含む組合せがちょうど  k 種類であるようなものを数え上げることにする。

このときこの  k 種類については、残りの  N-n 個のトッピングについてはなんでもよくて、その分の  (2^{N-n})^{k} = 2^{k(N-n)} 通りを掛ける。

さらにその  n 個のトッピングのいずれも乗っていないような組合せは  2^{N-n} 通りあって、そのそれぞれについて採用するかしないかは自由なので  2^{2^{N-n}} 通りを掛ける。

帰着された問題

結局われわれは

  •  n 個のトッピングの組合せの中から  k 種類を選ぶ
  • ただしどのトッピングも高々 1 回しか選ばれないようにする

という方法の数を求めたいといことになった。これを二通りの方法で求める。

解法 1: スターリング数を組み合わせる

場合分けする

  • どのトッピングについても 1 回ずつ乗せる場合
  • どれかのトッピングについては一度も使用しない場合

前者はスターリング数そのもので、 S(n, k) 通り。
後者は実は  n 個のものを  k+1 個のグループに分けた上でどれか 1 個のグループを特別扱いして削除する方法の数に等しいので、 (k+1)S(n, k+1) 通り。

これらを合計すればよい。

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;


// modint
vector<int> MODS = { 1000000007 }; // 実行時に決まる
template<int IND = 0> struct Fp {
    long long val;
    
    int MOD = MODS[IND];
    constexpr Fp(long long v = 0) noexcept : val(v % MODS[IND]) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<IND>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<IND>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<IND> modpow(const Fp<IND> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};


// 二項係数ライブラリ
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};


// スターリング数 (n 個を k グループにわける、n >= k)
template<class T> struct Stirling {
    vector<vector<T> > S;
    constexpr Stirling(int MAX) noexcept : S(MAX, vector<T>(MAX, 0)) {
        S[0][0] = 1;
        for (int n = 1; n < MAX; ++n) {
            for (int k = 1; k <= n; ++k) {
                S[n][k] = S[n-1][k-1] + S[n-1][k] * k;
            }
        }
    }
    constexpr T get(int n, int k) {
        if (n < 0 || k < 0 || n < k) return 0;
        return S[n][k];
    }
};



const int MAX = 3100;
int main() {
    // 入力
    long long N;
    cin >> N >> MODS[0];
    using mint = Fp<>;
    
    // 前計算
    BiCoef<mint> bc(MAX); // 二項係数計算の前処理
    Stirling<mint> sl(MAX); // スターリング数の前処理

    // 2^n や 2^2^n の前計算、2^2^(n+1) = (2^2^n)^2
    vector<mint> two(MAX*MAX, 0), dtwo(MAX, 0);
    two[0] = 1, dtwo[0] = 2;
    for (int i = 1; i < MAX; ++i) dtwo[i] = dtwo[i-1] * dtwo[i-1];
    for (int i = 1; i < MAX*MAX; ++i) two[i] = two[i-1] * 2;
    
    // 求める
    mint res = 0;
    for (int n = 0; n <= N; ++n) {
        mint add = 0;
        for (int k = 0; k <= n; ++k) {
            mint jiyudo = two[(N-n)*k] * dtwo[N-n];
            mint core = sl.get(n, k) + sl.get(n, k+1) * (k+1);
            add += core * jiyudo;
        }
        mint choose = bc.com(N, n);
        add *= choose;
        if (n % 2 == 0) res += add;
        else res -= add;
    }
    cout << res << endl;
}

解法 2: スターリング数を真似した DP

求めたいのは以下を満たす場合の数であった。

  •  n 個のトッピングの組合せの中から  k 種類を選ぶ
  • ただしどのトッピングも高々 1 回しか選ばれないようにする

これをスターリング数を求める DP と同じような DP で求めることを考える。

  • 要素 1 が使われないとき: f(n-1, k) 通り
  • 要素 1 が単独で使われるとき:  f(n-1, k-1) 通り
  • 要素 1 が単独でなく使われるとき: f(n-1, k) × k 通り (残りの  n-1 個を  k グループにわけて、1 をどこに入れるかが  k 通り)

となるので

  •  f(n, k) = f(n-1, k) + f(n-1, k-1) + k f(n-1, k)

となることがわかった。

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;


// modint
vector<int> MODS = { 1000000007 }; // 実行時に決まる
template<int IND = 0> struct Fp {
    long long val;
    
    int MOD = MODS[IND];
    constexpr Fp(long long v = 0) noexcept : val(v % MODS[IND]) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<IND>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<IND>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<IND> modpow(const Fp<IND> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};


// 二項係数ライブラリ
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};


// スターリング数 (n 個を k グループにわける、n >= k)
template<class T> struct Stirling {
    vector<vector<T> > S;
    constexpr Stirling(int MAX) noexcept : S(MAX, vector<T>(MAX, 0)) {
        S[0][0] = 1;
        for (int n = 1; n < MAX; ++n) {
            for (int k = 1; k <= n; ++k) {
                S[n][k] = S[n-1][k-1] + S[n-1][k] * k;
            }
        }
    }
    constexpr T get(int n, int k) {
        if (n < 0 || k < 0 || n < k) return 0;
        return S[n][k];
    }
};



const int MAX = 3100;
int main() {
    // 入力
    long long N;
    cin >> N >> MODS[0];
    using mint = Fp<>;
    
    // 前計算
    BiCoef<mint> bc(MAX); // 二項係数計算の前処理
    vector<vector<mint> > core(MAX, vector<mint>(MAX, 0));
    core[0][0] = 1;
    for (int n = 1; n < MAX; ++n) {
        for (int k = 0; k <= n; ++k) {
            core[n][k] += core[n-1][k];
            if (k-1 >= 0) core[n][k] += core[n-1][k-1];
            core[n][k] += core[n-1][k] * k;
        }
    }

    // 2^n や 2^2^n の前計算、2^2^(n+1) = (2^2^n)^2
    vector<mint> two(MAX*MAX, 0), dtwo(MAX, 0);
    two[0] = 1, dtwo[0] = 2;
    for (int i = 1; i < MAX; ++i) dtwo[i] = dtwo[i-1] * dtwo[i-1];
    for (int i = 1; i < MAX*MAX; ++i) two[i] = two[i-1] * 2;
    
    // 求める
    mint res = 0;
    for (int n = 0; n <= N; ++n) {
        mint add = 0;
        for (int k = 0; k <= n; ++k) {
            mint jiyudo = two[(N-n)*k] * dtwo[N-n];
            add += core[n][k] * jiyudo;
        }
        mint choose = bc.com(N, n);
        add *= choose;
        if (n % 2 == 0) res += add;
        else res -= add;
    }
    cout << res << endl;
}