FFT の勉強シリーズその 1 なん
yukicoder 0206 数の積集合を求めるクエリ
問題概要
どの 2 つも互いに相異なるサイズ L の数列 A[0], A[1], ..., A[L-1] と
どの 2 つも互いに相異なるサイズ M の数列 B[0], B[1], ..., B[M-1] とが与えられる。
各 A[i], B[j] の値は 1 以上 N 以下の整数である。
q = 0, 1, 2, ..., Q-1 に対して、
との共通部分集合のサイズを出力せよ。
制約
- 1 ≤ Q ≤ N ≤ 105
解法
想定解法ではないみたいだが、FFT がピッタリはまる。
数列 A, B をそれぞれ別の捉え方をして
- A: ベクトル ( ) ( は数列 A に i+1 があるかどうか)
- B: ベクトル ( ) ( は数列 B に i+1 があるかどうか)
としてみる。A = {1, 3, 4} のときはベクトルは (1, 0, 1, 1, 0, ...) になる。A と B はともにサイズ のベクトルである。こうすると、 のときというのはこれらのベクトルの内積になる。 の場合は B の各要素が右にスライドしたものと内積をとる感じになる。
さて、こういうスライドしていくベクトル同士の内積となると FFT を思い浮かべるのは恐らく典型だと推測される。多項式にしてみる:
- A:
- B:
こうすると、 のときは、多項式 A と B の次数の「差」が のところの係数を掛け算したものを合計したものになる。このままだと畳み込みっぽくない。そこで B の係数をひっくり返してみる
- B:
こうすると、 のときは、多項式 A と B の次数の「和」が のところの係数を掛け算したものを合計したものになる。これはすなわち、A と B を多項式として掛け算したものの次数を見ればいいことになる。
ここまで来れば A と B を多項式として掛け算するのを FFT によって で計算できる。
#include <iostream> #include <vector> #include <cmath> using namespace std; struct ComplexNumber { double real, imag; inline ComplexNumber& operator = (const ComplexNumber &c) {real = c.real; imag = c.imag; return *this;} friend inline ostream& operator << (ostream &s, const ComplexNumber &c) {return s<<'<'<<c.real<<','<<c.imag<<'>';} }; inline ComplexNumber operator + (const ComplexNumber &x, const ComplexNumber &y) { return {x.real + y.real, x.imag + y.imag}; } inline ComplexNumber operator - (const ComplexNumber &x, const ComplexNumber &y) { return {x.real - y.real, x.imag - y.imag}; } inline ComplexNumber operator * (const ComplexNumber &x, const ComplexNumber &y) { return {x.real * y.real - x.imag * y.imag, x.real * y.imag + x.imag * y.real}; } inline ComplexNumber operator * (const ComplexNumber &x, double a) { return {x.real * a, x.imag * a}; } inline ComplexNumber operator / (const ComplexNumber &x, double a) { return {x.real / a, x.imag / a}; } struct FFT { static const int MAX = 1<<18; // must be 2^n ComplexNumber AT[MAX], BT[MAX], CT[MAX]; void DTM(ComplexNumber F[], bool inv) { int N = MAX; for (int t = N; t >= 2; t >>= 1) { double ang = acos(-1.0)*2/t; for (int i = 0; i < t/2; i++) { ComplexNumber w = {cos(ang*i), sin(ang*i)}; if (inv) w.imag = -w.imag; for (int j = i; j < N; j += t) { ComplexNumber f1 = F[j] + F[j+t/2]; ComplexNumber f2 = (F[j] - F[j+t/2]) * w; F[j] = f1; F[j+t/2] = f2; } } } for (int i = 1, j = 0; i < N; i++) { for (int k = N >> 1; k > (j ^= k); k >>= 1); if (i < j) swap(F[i], F[j]); } } // C is A*B void mult(long long A[], long long B[], long long C[]) { for (int i = 0; i < MAX; ++i) AT[i] = {(double)A[i], 0.0}; for (int i = 0; i < MAX; ++i) BT[i] = {(double)B[i], 0.0}; DTM(AT, false); DTM(BT, false); for (int i = 0; i < MAX; ++i) CT[i] = AT[i] * BT[i]; DTM(CT, true); for (int i = 0; i < MAX; ++i) { CT[i] = CT[i] / MAX; C[i] = (long long)(CT[i].real + 0.5); } } }; int main() { int L, M, N, Q; cin >> L >> M >> N; long long A[FFT::MAX], B[FFT::MAX], C[FFT::MAX]; for (int i = 0; i < L; ++i) { int a; cin >> a; A[a-1] = 1; } for (int i = 0; i < M; ++i) { int b; cin >> b; B[N-b] = 1; } FFT f; f.mult(A, B, C); cin >> Q; for (int i = 0; i < Q; ++i) { cout << C[N-1+i] << endl; } }