けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Educational Codeforces Round 057 G - Lucky Tickets (R2400)

NTT と聞いて

問題へのリンク

問題概要

偶数  n が与えられる。
十進法表記で  d_1, d_2, \dots, d_k ( k \le 10) しか登場しない  n 桁の整数のうち、

  • 前半  \frac{n}{2} 桁の各位の和
  • 後半  \frac{n}{2} 桁の各位の和

が等しいものが何通りあるか、998244353 で割ったあまりで答えよ。leading zero は OK。

制約

  •  2 \le n \le 2 × 10^{5}

考えたこと

「等しい和」ごとに数え上げて二乗していけばよい。

  • dp[ m ][ s ] := m 個の数 (d1, d2, ..., dk のみ) の和が s となるような場合の数

とすると自然にナップサック DP だが、このままだと  O(n^{2}) なので計算量を減らす必要がある。DP の高速化はとりあえず式を書いて式とにらめっこするのがいい気がする、たぶん。

というわけで DP の遷移式を書いてみる:

  • dp[ m+1 ][ s ] += dp[ m ][ s ]
  • dp[ m+1 ][ s ] += dp[ m ][ s - d_k ] for each k

これを見ていると、

  • 遷移式が m に依存しない
  • 遷移式が s にも依存しない (s と s - d_k との差が d_k で一定)

m にも依存しないというあたりは、「行列累乗」とか「FFT」っぽさがある。今回は、dp[ m ] を多項式にして考えてみることにする。例えば、d = (0, 2, 4, 5) のとき、

  • dp[ 1 ] = 1 + x2 + x4 + x5 (つまり、0, 2, 4, 5 がそれぞれ 1 通りずつ)
  • dp[ 2 ] = 1 + 2x2 + 3x4 + 2x5 + 2x6 + 2x7 + x8 + 2x9 + x10

になる。こうして見ると、dp[ 2 ] は多項式として、dp[ 1 ] × dp[ 1 ] になっていることがわかる。同様に dp[ m ] は dp[ 1 ] の m 乗である。

というわけで多項式の乗算になっているので NTT がバッチリはまる。

NTT

dp[ 1 ] のなす多項式 f(x) とする。多項式としての累乗  f(x)^{\frac{n}{2}} を求めたい。手順としては

  •  f(x) を NTT して、多項式  g(x) を得る
  • NTT した世界では、多項式の積は係数同士をそのままかけ算すればいい。各係数を  \frac{n}{2} 乗する (ここで二分累乗法を用いる)。こうして得られた多項式 h(x) とする
  • 最後に  h(x) を逆 NTT して、答え多項式  r(x) を得る

そして最後に、多項式  r(x) i 次の係数は、各位の和が  i となるような  \frac{n}{2} 桁の整数の個数を表すようになる。

よって、各係数を二乗して足してあげればいい。計算量は、

  • 多項式の次数は  O(kN) = O(N)
  • NTT 計算、その逆計算に、 O(N\log{N})
  • 各係数ごとに二分累乗するのに  O(\log{N})、全部で  O(N\log{N})

よって全体として、 O(N\log{N}) となる。なお、NTT を行うためのライブラリは、ここに公開した。

#include <iostream>
#include <vector>
using namespace std;

long long modpow(long long a, long long n, long long mod) {
    long long res = 1;
    while (n > 0) {
        if (n & 1) res = res * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return res;
}

long long modinv(long long a, long long mod) {
    long long b = mod, u = 1, v = 0;
    while (b) {
        long long t = a/b;
        a -= t*b; swap(a, b);
        u -= t*v; swap(u, v);
    }
    u %= mod;
    if (u < 0) u += mod;
    return u;
}

namespace NTT {
    const int MOD = 998244353;  // to be set appropriately
    const long long PR = 3;     // to be set appropriately
    
    void trans(vector<long long> &v, bool inv = false) {
        int n = (int)v.size();
        for (int i = 0, j = 1; j < n-1; j++) {
            for (int k = n>>1; k > (i ^= k); k >>= 1);
            if (i > j) swap(v[i], v[j]);
        }
        for (int t = 2; t <= n; t <<= 1) {
            long long bw = modpow(PR, (MOD-1)/t, MOD);
            if (inv) bw = modinv(bw, MOD);
            for (int i = 0; i < n; i += t) {
                long long w = 1;
                for (int j = 0; j < t/2; ++j) {
                    int j1 = i + j, j2 = i + j + t/2;
                    long long c1 = v[j1], c2 = v[j2] * w % MOD;
                    v[j1] = c1 + c2;
                    v[j2] = c1 - c2 + MOD;
                    while (v[j1] >= MOD) v[j1] -= MOD;
                    while (v[j2] >= MOD) v[j2] -= MOD;
                    w = w * bw % MOD;
                }
            }
        }
        if (inv) {
            long long inv_n = modinv(n, MOD);
            for (int i = 0; i < n; ++i) v[i] = v[i] * inv_n % MOD;
        }
    }
    
    // C is A*B
    vector<long long> mult(vector<long long> A, vector<long long> B) {
        int size_a = 1; while (size_a < A.size()) size_a <<= 1;
        int size_b = 1; while (size_b < B.size()) size_b <<= 1;
        int size_fft = max(size_a, size_b) << 1;
        
        vector<long long> cA(size_fft, 0), cB(size_fft, 0), cC(size_fft, 0);
        for (int i = 0; i < A.size(); ++i) cA[i] = A[i];
        for (int i = 0; i < B.size(); ++i) cB[i] = B[i];
        
        trans(cA); trans(cB);
        for (int i = 0; i < size_fft; ++i) cC[i] = cA[i] * cB[i] % MOD;
        trans(cC, true);
        
        vector<long long> res((int)A.size() + (int)B.size() - 1);
        for (int i = 0; i < res.size(); ++i) res[i] = cC[i];
        return res;
    }
};



int main() {
    int n, k; cin >> n >> k;
    n /= 2;
    vector<long long> v(1<<22, 0);
    for (int i = 0; i < k; ++i) {
        int d; cin >> d;
        v[d] = 1;
    }
    NTT::trans(v);
    vector<long long> each(1<<22, 1);
    for (int i = 0; i < each.size(); ++i) {
        int N = n;
        while (N > 0) {
            if (N & 1) each[i] = each[i] * v[i] % NTT::MOD;
            v[i] = v[i] * v[i] % NTT::MOD;
            N >>= 1;
        }
    }
    NTT::trans(each, true);
    
    long long res = 0;
    for (int i = 0; i < each.size(); ++i) {
        res += each[i] * each[i] % NTT::MOD;
    }
    res %= NTT::MOD;
    cout << res << endl;
}