けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 104 D - Multiset Mean (黄色, 700 点)

すごく NTT したくなる

問題へのリンク

問題概要

正の整数  N, K が与えられる。 1, 2, \dots, N をそれぞれ  0 個以上  K 個以下とってくる方法のうち、平均が  x (= 1, 2, \dots, N) となるものの個数を素数  M で割ったあまりを、各  x に対して求めよ。

制約

  •  1 \le N, K \le 100

考えたこと

まず、平均制約を次のように言い換える。これをすることで「平均」に関する制約が「総和」に関する制約となって扱いやすくなる!

  • 整数  a, b, c, \dots の平均が  x である
  • 整数  a-x, b-x, c-x, \dots の総和が  0 である

これを踏まえて考えて行く。たとえば  N = 6 x = 3 のとき

  • -2 を  K 個以下
  • -1 を  K 個以下
  • 0 を  K 個以下
  • 1 を  K 個以下
  • 2 を  K 個以下
  • 3 を  K 個以下

ずつ選ぶ方法であって、総和が 0 になるものを数え上げる問題となる。これはさらに次のように考えることができる。

  • 0 は何個選んでもいいので  K+1 通り
  • マイナス部分から選んだ総和、プラス部分から選んだ総和が等しくなることが必要十分

よって、

  • dp[ v ][ s ] := 1, 2, ..., v を K 個以下ずつ選んでできる総和が s になる場合の数

というのを前処理しておけば、各 x に対する答えは次のように求められる。なお、s として考えるべき値のオーダーは  O(N^{2}K) となる。


 s = 1, 2, \dots に対して

res += dp[ x - 1 ][ s ] × dp[ N - x ][ s ]

として、答えは res × (K + 1) - 1 となる


前処理がちゃんとできていれば、各 x に対して答えを求める作業は  O(N^{3}K) でできる。

よって残りは dp[ v ][ s ] を求める問題へと帰着された。

 

解法 (1):DP 高速化 (いもす法 or 累積和)

DP 遷移式はナイーブには次のように考えられる。

dp[ v ][ s + v × i ] += dp[ v - 1 ][ s ] (i = 0, 1, ..., K)

しかしこれでは、DP テーブルの状態量が  O(N^{3}K) あって、遷移が  O(K) だけあるので、全体で  O(N^{3}K^{2}) の計算量となって間に合わない。

しかし上の更新式を見ると、次のことがわかる。


dp[ v - 1 ][ i ] から dp[ v ][ j ] への遷移は、i と j を v で割ったあまりが等しい部分でしか発生しない


このことから、配列 dp の第二添字を v でわったあまりごとに独立に考えてよい。そのとき、

dp[ v ][ s + v × i ] += dp[ v - 1 ][ s ] (i = 0, 1, ..., K)

という更新は「連続する区間に一律に値 dp[ v -1 ][ s ] を足している」というものとなっている。よっていもす法によって高速化できる (集める DP で書けば累積和で高速化できる)。

この高速化を行うことで前処理・クエリ処理の計算量はともに  O(N^{3}K) となって間に合う。

#include <bits/stdc++.h>
using namespace std;

// modint
vector<int> MODS = { 1000000007 }; // 実行時に決まる
template<int IND = 0> struct Fp {
    long long val;
    
    int MOD = MODS[IND];
    constexpr Fp(long long v = 0) noexcept : val(v % MODS[IND]) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<IND>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<IND>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<IND> modpow(const Fp<IND> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};
using mint = Fp<>;

void solve(int N, int K) {
    int S = 0;
    for (int i = 1; i <= N; ++i) S += i * K;
    vector<vector<mint>> dp(N+1, vector<mint>(S+1, 0));
    dp[0][0] = 1;
    for (int v = 1; v <= N; ++v) {
        for (int s = 0; s <= S; ++s) {
            dp[v][s] += dp[v-1][s];
            if (s+v*(K+1) <= S) dp[v][s+v*(K+1)] -= dp[v-1][s];
        }
        for (int s = 0; s+v <= S; ++s) dp[v][s+v] += dp[v][s];
    }
    for (int k = 1; k <= N; ++k) {
        mint res = 0;
        for (int s = 0; s <= S; ++s) res += dp[k-1][s] * dp[N-k][s];
        res *= (K+1);
        cout << res-1 << endl;
    }
}

int main() {
    int N, K;
    cin >> N >> K >> MODS[0];
    solve(N, K);
}

 

解法 (2):DP 高速化パートを形式的冪級数で導出

上では DP 高速化部分をいもす法でやった。しかしこのような DP 遷移を見ていると、次のような NTT をどうしてもやりたくなってしまう。

  •  (1 + x + x^{2} + \dots + x^{K})
  •  \times (1 + x^{2} + x^{4} + \dots + x^{2K})
  •  \times (1 + x^{3} + x^{6} + \dots + x^{3K})
  •  \times \dots

配列 dp[ v ] は、これを最初の v 個まで掛け合わせたときの各係数になる。よって NTT を用いれば  O(N^{3}K(\log N + \log K)) でできるのだが、僕のライブラリでは TLE した。ものすごく速い NTT なら間に合うのかもしれない。

しかし形式的冪級数を少し式変形すると十分扱えるものになる。

 1 + x^{v} + x^{2v} + \dots + x^{Kv} = \frac{1 - x^{(K+1)v}}{1 - x^{v}}

よって、一般に形式的冪級数に対して

  •  1 - x^{a} を掛ける
  •  1 - x^{a} で割る

という操作が高速にできるなら良さそう。

1 - xa を掛ける

形式的冪級数 1 - x^{a} を掛けるのは、DP 遷移 (配列 dp から配列 ndp への遷移) の言葉で翻訳すると次のような遷移になる

ndp[ i ] = dp[ i ] - dp[ i - a ]

この遷移は  O(s) でできる ( s = O(NK^{2}))。また、実装上は in-place に実装できる (i を降順に更新)。

1 - xa で割る

上の遷移の逆変換をとればよいので、次のようになる

dp[ i ] = ndp[ i ] + dp[ i - a ]

この遷移も  O(s) でできる。これも実装上は in-place に実装できる (i を昇順に更新)。

まとめ

以上から、DP パートはまとめて  O(N^{3}K) でできるので間に合う。そして上の DP は実は「いもす法による高速化」と完全に一致する。

 

解法 (3):前処理なしで、形式的冪級数で殴る

解法 (2) の形式的冪級数を用いた議論を突き詰めると、以下のような問題を直接殴ることもできる。

  •  1-x K 個以下
  •  2-x K 個以下
  • ...
  •  N-x K 個以下

選ぶ方法のうち、総和が 0 になる方法を直接数え上げる。まずは  x = 1 の場合の解を求めてあげる。そして  x を増やしていくとき、次のような更新を行えばよい。


  • 配列 dp に対して、「 N-x+1 K 個以下利用できる」という操作を削除する更新を行う (戻す DP)
  • 配列 dp に対して、「 1-x K 個以下利用できる」という操作を挿入する更新を行う

これを形式的冪級数の言葉で整理すると、次のようになる。

  •  v = N-x+1 として、 1 - x^{(K+1)v} で割る
  •  v = N-x+1 として、 1 - x^{v} をかける
  •  v = 1-x として、 1 - x^{(K+1)v} をかける
  •  v = 1-x として、 1 - x^{v} で割る

ただしこのままでは v < 0 になりうるので、次のように言い換える


  •  v = N-x+1 として、 1 - x^{(K+1)v} で割る
  •  v = N-x+1 として、 1 - x^{v} をかける
  •  v = x-1 として、 1 - x^{(K+1)v} をかける
  •  v = x-1 として、 1 - x^{v} で割る
  • 最後に  x^{-Kv} をかける

そして実装上は  x^{-Kv} をかけるのを省略して、代わりに dp 値を取得すべき添字が Kv だけ右にずれると考えれば OK。以上より、各差分更新を  O(N^{2}K) で実施できるので、全体の計算量は  O(N^{3}K) となる。

#include <bits/stdc++.h>
using namespace std;

// modint
vector<int> MODS = { 1000000007 }; // 実行時に決まる
template<int IND = 0> struct Fp {
    long long val;
    
    int MOD = MODS[IND];
    constexpr Fp(long long v = 0) noexcept : val(v % MODS[IND]) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<IND>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<IND>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<IND> modpow(const Fp<IND> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};
using mint = Fp<>;

void mul(vector<mint> &dp, int a) {
    for (int i = (int)dp.size()-1; i >= a; --i) dp[i] -= dp[i-a];
}

void div(vector<mint> &dp, int a) {
    for (int i = a; i < dp.size(); ++i) dp[i] += dp[i-a];
}

void solve(int N, int K) {
    int S = 0;
    for (int i = 1; i <= N; ++i) S += i * K;
    vector<mint> dp(S+1, 0);
    dp[0] = 1;

    for (int v = 1; v < N; ++v) {
        mul(dp, (K+1)*v);
        div(dp, v);
    }
    cout << dp[0]*(K+1)-1 << endl;
    int D = 0;
    for (int x = 2; x <= N; ++x) {
        div(dp, (K+1)*(N-x+1));
        mul(dp, N-x+1);
        mul(dp, (K+1)*(x-1));
        div(dp, x-1);
        D += (x-1)*K;
        cout << dp[D]*(K+1)-1 << endl;
    }
}

int main() {
    int N, K;
    cin >> N >> K >> MODS[0];
    solve(N, K);
}