けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 107 D - Number of Multisets (黄色, 600 点)

いろんな DP が考えられそう!

問題概要

正整数  N,K が与えらる。以下の条件を全て満たす有理数の多重集合は何種類存在するか、998244353 で割ったあまりを求めよ。

  • 多重集合の要素数 N
  • 多重集合の要素の総和は  K
  • 多重集合の要素は全て  \frac{1}{2^{i}} ( i は非負整数) の形

制約

  •  1 \le K \le N \le 3000

解法 (1):僕の考えた解法

想定解法とは違う解法でやった。「条件を満たす対象を一意に定めるような操作列を考えて、それを数え上げる」というのも一つの典型だと思う。

まず、今回の数列は次のような操作によって一意に定められることに注意する。

  • まず 1 のうちのいくつかを分裂させる
    • ex: (1, 1, 1) -> (1, 1/2, 1/2, 1/2, 1/2)
  • 1/2 のうちのいくつかを分裂させる
    • ex: (1, 1/2, 1/2, 1/2, 1/2) -> (1, 1/2, 1/2, 1/2, 1/4, 1/4)
  • 1/4 のうちのいくつかを分裂させる
    • ex: (1, 1/2, 1/2, 1/2, 1/4, 1/4) -> (1, 1/2, 1/2, 1/2, 1/8, 1/8, 1/8, 1/8)
  • 1/8 のうちのいくつかを分裂させる
    • ex: (1, 1/2, 1/2, 1/2, 1/8, 1/8, 1/8, 1/8) -> ((1, 1/2, 1/2, 1/2, 1/8, 1/8, 1/16, 1/16, 1/16, 1/16)
  • ...

この操作において、

  • 分裂させた対象の個数が  N-K
  • 分裂させる対象が単調減少

というところがポイントで、それによって、「操作列」と「生成多重集合」とが一対一に対応する!!!よって操作列の方を数え上げればよい。

DP へ

  • dp[ i ][ j ] := 合計 i 回の分裂操作を行った結果、結果得られた生成多重集合の最小の値の個数が j 個になるような操作列の個数

とする。このとき DP 遷移は、j を偶数として、k = j/2, j/2+1, ... に対して

dp[ i ][ j ] += dp[ i - j/2 ][ k ]

と表せる。このままでは  O(N^{3}) の計算量となる。しかしながら、累積和を用いて高速化することで  O(N^{2}) となる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int main() {
    int N, K;
    cin >> N >> K;
    vector<vector<mint>> dp(N-K+1, vector<mint>(N+1, 0));
    vector<vector<mint>> sdp(N+K+1, vector<mint>(N+2, 0));

    auto accum = [&](int i) -> void {
        for (int j = 0; j <= N; ++j) sdp[i][j+1] = sdp[i][j] + dp[i][j];
    };
    dp[0][K] = 1;
    accum(0);
    for (int i = 1; i <= N-K; ++i) {
        for (int j = 2; i-j/2 >= 0 && j <= N; j += 2) {
            dp[i][j] = sdp[i-j/2].back() - sdp[i-j/2][j/2];
        }
        accum(i);
    }
    cout << sdp[N-K].back() << endl;
}

解法 (2):賢い DP - 想定解法

もっと賢い DP でも解ける!!これ賢い!!!

求める答えを  f(N, K) とおいたとき、これを求めるために「1 を使うか使わないか」で場合分けをする。かなり賢い!!!!

1 を使うとき

その「1」を除去した問題を再度考えることができる。それは

「1, 1/2, ... を使って、 N-1 個で  K-1 を作れ」

という問題となるので、 f(N-1, K-1) 通りになる。

1 を使わないとき

その場合は

「1/2, 1/4, ... を使って、 N 個で  K を作れ」

という問題となる。これは全体を 2 倍すると、

「1, 1/2, ... を使って、 N 個で  2K を作れ

という問題になることがわかる。よって  f(N, 2K) 通りとなる。

 
まとめると、

 f(N, K) = f(N-1, K-1) + f(N, 2K)

となる。これは  O(N^{2}) で解ける。ここで、 N \lt K の場合は  f(N, K) = 0 となることに注意しておく。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

mint func(int N, int K, vector<vector<mint>> &dp) {
    if (N == 0) {
        if (K == 0) return mint(1);
        else return mint(0);
    }
    if (N < 0 || K <= 0 || K > N) return mint(0);
    if (dp[N][K] != mint(-1)) return dp[N][K];
    return dp[N][K] = func(N-1, K-1, dp) + func(N, K*2, dp);
}

int main() {
    int N, K;
    cin >> N >> K;
    vector<vector<mint>> dp(N+1, vector<mint>(N+1, mint(-1)));
    cout << func(N, K, dp) << endl;
}