けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

yukicoder No.980 Fibonacci Convolution Hard (母関数)

すごい面白かった!!!

問題へのリンク

問題概要

整数  p が与えらえたときに、漸化式

  •  a_{1} = 0
  •  a_{2} = 1
  •  a_{n} = p a_{n-1} + a_{n-2}

を満たす数列  a_{n} が与えられる。これに対して以下の  Q 個のクエリに答えよ:

  • 1 つのクエリは 2 以上の整数  q が指定され、 q = s + t を満たすような正の整数  s, t に対して、 a_{s} \times a_{t} の総和を求め、1000000007 で割ったあまりを求めよ。

制約

  •  1 \le p \le 10^{9}
  •  1 \le Q \le 2 \times 10^{5}
  •  2 \le q_{i} \le 2 \times 10^{6}

数列の母関数

いかにも畳み込みっぽい式をしているので、母関数 (形式的べき級数) で考えてみたくなる。さて、線形漸化式の形で書かれた数列の母関数は実は簡単に求められる。

まず数列  a_{0}, a_{1}, \dots の母関数とは、

 f(x) = a_{0} + a_{1} x + a_{2} x^{2} + \dots

で定義される関数のこと。求めたい数列に対して、母関数がわかれば、その係数を見ることで数列の各項がわかる、みたいなことがよくある。

さて、たとえば例として三項間漸化式

  •  a_{0} = s
  •  a_{1} = t
  •  a_{n} + p a_{n-1} + q a_{n-2} = 0

で定義される数列の母関数を形式的に求めてみよう。

  •  f(x) = a_{0} + a_{1} x + a_{2} x^{2} + a_{3}x^{3} + a_{4}x^{4} + \dots
  •  px f(x) = pa_{0}x + pa_{1}x^{2} + pa_{2}x^{3} + pa_{3}x^{4} + \dots
  •  qx^{2}f(x) = qa_{0}x^{2} + qa_{1}x^{3} + qa_{2}x^{4} + \dots

として、これらを足すと、なんとほとんどの項が消えてしまうのだ。なぜなら漸化式の定義から、

  •  a_{2} + p a_{1} + q a_{0} = 0
  •  a_{3} + p a_{2} + q a_{1} = 0
  • ...

が成立するからだ。足すと

 (1 + px + qx^{2})f(x) = a_{0} + (a_{1} + p a_{0})x
 f(x) = \frac{s + (ps + t)x}{1 + px + qx^{2}}

と求められる。

今回の数列

今回の数列 (0-indexed にして、クエリの  q 2 引いておく) の母関数を  f(x) とすると、

 f(x) = \frac{x}{1 - px - x^{2}}

と求められる。ここで、

  •  b_{i} = \sum_{j = 0}^{i} a_{j} \times a_{i - j}

によって定義される数列  b_{i} の母関数  g(x) を求めることにする。数列  b_{i} の定義をよくよく考えると、

 (a_{0} + a_{1} x + a_{2} x^{2} + \dots)(a_{0} + a_{1} x + a_{2} x^{2} + \dots)

の各項の係数になっていることがわかる。よって単純に

 g(x) = f(x)f(x)

なのだ!!!実際に計算してみると、

 g(x) = \frac{x^{2}}{1 - 2px + (p^{2} - 2)x^{2} + 2p x^{3} + x^{4}}

となる。ちなみに分子の  x^{2} は恐れる必要はない。母関数を  x 倍したり  x で割ったりするのは、対応する数列の添字をずらす操作に他ならない。 g(x) は、むしろ母関数としては

 g'(x) = \frac{x^{3}}{1 - 2px + (p^{2} - 2)x^{2} + 2p x^{3} + x^{4}}

と表されるものがわかりやすくて、これを漸化式に直すと

  •  b_{0} = 0
  •  b_{1} = 0
  •  b_{2} = 0
  •  b_{3} = 1
  •  b_{n} = 2p b_{n-1} - (p^{2}-2) b_{n-2} - 2p b_{n-3} - b_{n-4}

となる!!!実際はこれに  x を割るので、項が 1 個だけずれることになる。よって数列  b_{i} は、


  •  b_{0} = 0
  •  b_{1} = 0
  •  b_{2} = 1
  •  b_{3} = 2p
  •  b_{n} = 2p b_{n-1} - (p^{2}-2) b_{n-2} - 2p b_{n-3} - b_{n-4}

で求められるので、あらかじめ制約の  2\times 10^{6} 程度まで求めておけば OK!!!

#include <bits/stdc++.h>
using namespace std;

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

const int MAX = 2100000;
const int MOD = 1000000007;
using mint = Fp<MOD>;

int main() {
    long long lp;
    cin >> lp;
    vector<mint> b(MAX, 0);
    mint p = lp;
    b[2] = 1, b[3] = p * 2;
    for (int n = 4; n < MAX; ++n) {
        b[n] = b[n-1]*p*2 - b[n-2]*(p*p-2) - b[n-3]*p*2 - b[n-4];
    }
    int Q; cin >> Q;
    for (int _ = 0; _ < Q; ++_) {
        int q; cin >> q;
        q -= 2;
        cout << b[q] << endl;
    }
}