けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder Library Practice Contest F - Convolution

本当に ACL の convolution をそのまま試してほしいという問題ですね。

問題概要

整数列  a_{0}, a_{1}, \dots, a_{N-1} と、整数列  b_{0}, b_{1}, \dots, b_{M-1} が与えられる。

 c_{i} = \displaystyle \sum_{j=0}^{i} a_{i} b_{i-j} \mod 998244353

によって定義される整数列  c_{0}, c_{1}, \dots, c_{N+M-1} を求めよ。

制約

  •  1 \le N, M \le 524288
  •  0 \le a_{i}, b_{i} \lt 998244353

解法

とにかく、ACL のドキュメントにそのままの式が書いてある!

https://atcoder.github.io/ac-library/production/document_ja/convolution.html

次の関数が使える。

vector<T> convolution<int m = 998244353>(vector<T> a, vector<T> b)

この関数が使えるための条件のうち、ぱっと見ではわかりにくい

 2^{c} | (m - 1) かつ  |a| + |b| - 1 \le 2^{c} なる整数  c が存在する」

を確かめておく。

まず、 m = 998244353 はとてもいい素数で、

 m-1 = 2^{23} \times 119

を満たす。よって、

 |a| + |b| - 1 \lt 524288 + 524288 = 2^{19} + 2^{19} = 2^{20}

であることから、 c = 20 が条件を満たす ( c = 21, 22, 23 も OK)。

コード

関数 convolution<int m = 998244353> は、デフォルトで 998244353 を法としているので、次のようなコードで OK。

#include <bits/stdc++.h>
#include <atcoder/convolution>
using namespace std;
using namespace atcoder;

int main() {
    int N, M;
    cin >> N >> M;
    vector<long long> A(N), B(M);
    for (int i = 0; i < N; ++i) cin >> A[i];
    for (int i = 0; i < M; ++i) cin >> B[i];
    
    vector<long long> C = convolution(A, B);
    for (int i = 0; i < N + M - 1; ++i) cout << C[i] << " ";
    cout << endl;
}

modint を使っても OK

さらに、ACL では次の関数も定義されている。

vector<static_modint<m>> convolution<int m>(vector<static_modint<m>> a, vector<static_modint<m>> b)

つまり、次のコードのように、配列  A, B, C の要素の型として long long 型ではなく mint 型を使っても OK。

#include <bits/stdc++.h>
#include <atcoder/convolution>
#include <atcoder/modint>
using namespace std;
using namespace atcoder;

using mint = modint998244353;

int main() {
    int N, M;
    cin >> N >> M;
    vector<mint> A(N), B(M);
    for (int i = 0; i < N; ++i) {
        long long a;
        cin >> a;
        A[i] = a;
    }
    for (int i = 0; i < M; ++i) {
        long long b;
        cin >> b;
        B[i] = b;
    }
    
    vector<mint> C = convolution(A, B);
    for (int i = 0; i < N + M - 1; ++i) cout << C[i].val() << " ";
    cout << endl;
}