けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

yukicoder 0206 数の積集合を求めるクエリ

FFT の勉強シリーズその 1 なん

yukicoder 0206 数の積集合を求めるクエリ

問題概要

どの 2 つも互いに相異なるサイズ L の数列 A[0], A[1], ..., A[L-1] と
どの 2 つも互いに相異なるサイズ M の数列 B[0], B[1], ..., B[M-1] とが与えられる。
各 A[i], B[j] の値は 1 以上 N 以下の整数である。

q = 0, 1, 2, ..., Q-1 に対して、

  • 集合 {A[0], A[1], ..., A[L-1]}
  • 集合 {B[0] + q, B[1] + q, ..., B[M-1] + q}

との共通部分集合のサイズを出力せよ。

制約

  • 1 ≤ Q ≤ N ≤ 105

解法

想定解法ではないみたいだが、FFT がピッタリはまる。

数列 A, B をそれぞれ別の捉え方をして

  • A: ベクトル (  a_0, a_1, a_2, \dots, a_{N-1} ) ( a_{i} は数列 A に i+1 があるかどうか)
  • B: ベクトル (  b_0, b_1, b_2, \dots, b_{N-1} ) ( b_{i} は数列 B に i+1 があるかどうか)

としてみる。A = {1, 3, 4} のときはベクトルは (1, 0, 1, 1, 0, ...) になる。A と B はともにサイズ  N のベクトルである。こうすると、 q = 0 のときというのはこれらのベクトルの内積になる。 q \ge 1 の場合は B の各要素が右にスライドしたものと内積をとる感じになる。

さて、こういうスライドしていくベクトル同士の内積となると FFT を思い浮かべるのは恐らく典型だと推測される。多項式にしてみる:

  • A:  a_{0} + a_{1} x + a_{2} x^{2} + \dots + a_{N-1} x^{N-1}
  • B:  b_0 + b_1 x + b_2 x^{2} + \dots + b_{N-1} x^{N-1}

こうすると、 q のときは、多項式 A と B の次数の「」が  q のところの係数を掛け算したものを合計したものになる。このままだと畳み込みっぽくない。そこで B の係数をひっくり返してみる

  • B:  b_{N-1} + b_{N-2} x + \dots + b_{0} x^{N-1}

こうすると、 q のときは、多項式 A と B の次数の「」が  N-1+q のところの係数を掛け算したものを合計したものになる。これはすなわち、A と B を多項式として掛け算したものの次数を見ればいいことになる。

ここまで来れば A と B を多項式として掛け算するのを FFT によって  O(N\log{N}) で計算できる。

#include <iostream>
#include <vector>
#include <cmath>
using namespace std;

struct ComplexNumber {
    double real, imag;
    inline ComplexNumber& operator = (const ComplexNumber &c) {real = c.real; imag = c.imag; return *this;}
    friend inline ostream& operator << (ostream &s, const ComplexNumber &c) {return s<<'<'<<c.real<<','<<c.imag<<'>';}
};
inline ComplexNumber operator + (const ComplexNumber &x, const ComplexNumber &y) {
    return {x.real + y.real, x.imag + y.imag};
}
inline ComplexNumber operator - (const ComplexNumber &x, const ComplexNumber &y) {
    return {x.real - y.real, x.imag - y.imag};
}
inline ComplexNumber operator * (const ComplexNumber &x, const ComplexNumber &y) {
    return {x.real * y.real - x.imag * y.imag, x.real * y.imag + x.imag * y.real};
}
inline ComplexNumber operator * (const ComplexNumber &x, double a) {
    return {x.real * a, x.imag * a};
}
inline ComplexNumber operator / (const ComplexNumber &x, double a) {
    return {x.real / a, x.imag / a};
}

struct FFT {
    static const int MAX = 1<<18;               // must be 2^n
    ComplexNumber AT[MAX], BT[MAX], CT[MAX];

    void DTM(ComplexNumber F[], bool inv) {
        int N = MAX;
        for (int t = N; t >= 2; t >>= 1) {
            double ang = acos(-1.0)*2/t;
            for (int i = 0; i < t/2; i++) {
                ComplexNumber w = {cos(ang*i), sin(ang*i)};
                if (inv) w.imag = -w.imag;
                for (int j = i; j < N; j += t) {
                    ComplexNumber f1 = F[j] + F[j+t/2];
                    ComplexNumber f2 = (F[j] - F[j+t/2]) * w;
                    F[j] = f1;
                    F[j+t/2] = f2;
                }
            }
        }
        for (int i = 1, j = 0; i < N; i++) {
            for (int k = N >> 1; k > (j ^= k); k >>= 1);
            if (i < j) swap(F[i], F[j]);
        }
    }
    
    // C is A*B
    void mult(long long A[], long long B[], long long C[]) {
        for (int i = 0; i < MAX; ++i) AT[i] = {(double)A[i], 0.0};
        for (int i = 0; i < MAX; ++i) BT[i] = {(double)B[i], 0.0};
        
        DTM(AT, false);
        DTM(BT, false);
        
        for (int i = 0; i < MAX; ++i) CT[i] = AT[i] * BT[i];
        
        DTM(CT, true);
        
        for (int i = 0; i < MAX; ++i) {
            CT[i] = CT[i] / MAX;
            C[i] = (long long)(CT[i].real + 0.5);
        }
    }
};

int main() {
    int L, M, N, Q;
    cin >> L >> M >> N;
    long long A[FFT::MAX], B[FFT::MAX], C[FFT::MAX];
    for (int i = 0; i < L; ++i) {
        int a; cin >> a; A[a-1] = 1;
    }
    for (int i = 0; i < M; ++i) {
        int b; cin >> b; B[N-b] = 1;
    }
    FFT f;
    f.mult(A, B, C);
    
    cin >> Q;
    for (int i = 0; i < Q; ++i) {
        cout << C[N-1+i] << endl;
    }
}