けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 106 D - Powers (青色, 600 点)

「要素を 1 個ずつ追加していくときに値がどう変化していくか」を観察する方向でずっと考えていて迷走してしまった...

問題概要

正の整数  K と、 N 個の整数  A_{1}, \dots, A_{N} が与えられる。 X = 1, 2, \dots, K に対して、

 \sum_{i=1}^{N-1} \sum_{j=i+1}^{N-1} (A_{i} + A_{j})^{X}

の値を 998244353 で割ったあまりを求めよ。

制約

  •  2 \le N \le 2 \times 10^{5}
  •  1 \le K \le 300

解法

 N 個のうち  2 個を選ぶ  {}_{N}{\rm C}_{2} 通りのペアについての総和を求める問題となっている。こういうのは確かに

  •  N 個のうちから重複も許して  2 個選んだ場合を求めて
  •  N 個のうちから重複して  2 個選ぶ場合を引いて
  • 最後に 2 で割る

という考え方が定石なのかもしれない。そっちの方向に行けなかった...。試しに  N = 3 x = 4 としてみる。このとき  3 \times 3 = 9 個の和になるが、そのうちの  (a + ?)^{x} の部分だけを抽出してみる。そうすると、

  •  (a+a)^{4} = {}_{4}{\rm C}_{0}a^{4}a^{0} + {}_{4}{\rm C}_{1}a^{3}a^{1} + {}_{4}{\rm C}_{2}a^{2}a^{2} + {}_{4}{\rm C}_{3}a^{1}a^{3} + {}_{4}{\rm C}_{4}a^{0}a^{4}
  •  (a+b)^{4} = {}_{4}{\rm C}_{0}a^{4}b^{0} + {}_{4}{\rm C}_{1}a^{3}b^{1} + {}_{4}{\rm C}_{2}a^{2}b^{2} + {}_{4}{\rm C}_{3}a^{1}b^{3} + {}_{4}{\rm C}_{4}a^{0}b^{4}
  •  (a+c)^{4} = {}_{4}{\rm C}_{0}a^{4}c^{0} + {}_{4}{\rm C}_{1}a^{3}c^{1} + {}_{4}{\rm C}_{2}a^{2}c^{2} + {}_{4}{\rm C}_{3}a^{1}c^{3} + {}_{4}{\rm C}_{4}a^{0}c^{4}

となっている。コンビネーションのところは  {}_{n}{\rm C}_{r} = \frac{n!}{r!(n-r)!} を使うと、かなり綺麗になる。

  •  \frac{(a+a)^{4}}{4!} = \frac{a^{4}}{4!} \frac{a^{0}}{0!} + \frac{a^{3}}{3!} \frac{a^{1}}{1!} + \frac{a^{2}}{2!} \frac{a^{2}}{2!} + \frac{a^{1}}{1!} \frac{a^{3}}{3!} + \frac{a^{0}}{0!} \frac{a^{4}}{4!}
  •  \frac{(a+b)^{4}}{4!} = \frac{a^{4}}{4!} \frac{b^{0}}{0!} + \frac{a^{3}}{3!} \frac{b^{1}}{1!} + \frac{a^{2}}{2!} \frac{b^{2}}{2!} + \frac{a^{1}}{1!} \frac{b^{3}}{3!} + \frac{a^{0}}{0!} \frac{b^{4}}{4!}
  •  \frac{(a+c)^{4}}{4!} = \frac{a^{4}}{4!} \frac{c^{0}}{0!} + \frac{a^{3}}{3!} \frac{c^{1}}{1!} + \frac{a^{2}}{2!} \frac{c^{2}}{2!} + \frac{a^{1}}{1!} \frac{c^{3}}{3!} + \frac{a^{0}}{0!} \frac{c^{4}}{4!}


さて、 a'_{i} = \frac{a^{i}}{i!} と表記することにすると、次のようにまとめられることがわかる。


 \frac{(a+a)^{4}}{4!} + \frac{(a+b)^{4}}{4!} + \frac{(a+c)^{4}}{4!}
 = a'_{4}(a'_{0} + b'_{0} + c'_{0})
 + a'_{3}(a'_{1} + b'_{1} + c'_{1})
 + a'_{2}(a'_{2} + b'_{2} + c'_{2})
 + a'_{1}(a'_{3} + b'_{3} + c'_{3})
 + a'_{0}(a'_{4} + b'_{4} + c'_{4})

さらに、 (b + ?)^{4} (c + ?)^{4} についても考えると、対称性から


 \frac{(a+a)^{4}}{4!} + \frac{(a+b)^{4}}{4!} + \frac{(a+c)^{4}}{4!} + \frac{(b+a)^{4}}{4!} + \frac{(b+b)^{4}}{4!} + \frac{(b+c)^{4}}{4!} + \frac{(c+a)^{4}}{4!} + \frac{(c+b)^{4}}{4!} + \frac{(c+c)^{4}}{4!}
 = (a'_{4} + b'_{4} + c'_{4})(a'_{0} + b'_{0} + c'_{0})
 + (a'_{3} + b'_{3} + c'_{3})(a'_{1} + b'_{1} + c'_{1})
 + (a'_{2} + b'_{2} + c'_{2})(a'_{2} + b'_{2} + c'_{2})
 + (a'_{1} + b'_{1} + c'_{1})(a'_{3} + b'_{3} + c'_{3})
 + (a'_{0} + b'_{0} + c'_{0})(a'_{4} + b'_{4} + c'_{4})

となる。よって、

  •  a'_{k} + b'_{k} + c'_{k}

をそれぞれ前処理で求めておけば、 x = 1, 2, \dots, K の場合をすべて求める作業を  O(K^{2}) でできることになる。

以上のことは一般の場合にも拡張できて、全体の計算量は  O(NK + K^{2}) となる。

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

// Binomial Coefficient
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int main() {
    int N, K;
    cin >> N >> K;
    BiCoef<mint> bc(K+1);
    vector<mint> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];

    vector<mint> S(K+1, 0), powA(N, 1);
    for (int k = 0; k <= K; ++k) {
        for (int i = 0; i < N; ++i) {
            S[k] += powA[i];
            powA[i] *= A[i];
        }
        S[k] *= bc.finv(k);
    }
    for (int k = 1; k <= K; ++k) {
        mint res = 0;
        for (int i = 0; i <= k; ++i) res += S[i] * S[k-i];
        res = (res - S[k] * modpow(mint(2), k)) * bc.fact(k) / 2;
        cout << res << endl;
    }
}

 

解法 (2):NTT で高速化

上のコードで、

for (int k = 1; k <= K; ++k) {
    mint res = 0;
    for (int i = 0; i <= k; ++i) res += S[i] * S[k-i];
}

という処理をしている。ここは NTT を用いて高速化できる!それをすることで  O(NK + K \log K) の計算量となる。

 

N, K ともに巨大でも

 N, K \le 10^{5} であっても、次の yukicoder の問題の知見を活用することで解ける模様。

drken1215.hatenablog.com

 

コンテスト中に考えていたこと

 N 個の値を対等に扱う思考に入れずに、ひたすら「1 個追加するとどうなるか」を考える方向に走ってしまった。

畳み込み計算も駆使して、 O(NK \log K) までにはなったのだけど、間に合わなかった。