けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 378 E - Mod Sigma Problem (1D, 水色, 475 点)

これ面白かった!

問題概要

数列  A_{1}, A_{2}, \dots, A_{N} が与えられる。この数列の連続する部分数列について「その総和を  M で割った余り」を考える。

連続する部分数列をすべて考えたときの、「その総和を  M で割った余り」の総和を求めよ。

制約

  •  1 \le N, M \le 2 \times 10^{5}
  •  0 \le A_{i} \le 10^{9}

考えたこと

この問題のように「数列の部分数列の総和」を考えるときには、累積和を考えるとよいと相場は決まっている(俗に言う Zero-Sum Ranges 法)。

数列  A の累積和を  S_{0}, S_{1}, \dots, S_{N} としよう(さらに、各項  S_{i} M で割った余りに置き換えておく)。これら  N+1 個の値から 2 つ  S_{x}, S_{y} ( x \lt y) を選んで、

 S_{y} - S_{x}  \mathrm{mod}  M

の値の総和を求めればよい。ここで、 S_{x}, S_{y} の大小関係を考えると、上の値は次のように整理できる。

  •  S_{x} \le S_{y} のとき: S_{y} - S_{x}
  •  S_{x} \gt S_{y} のとき: (S_{y} - S_{x}) + M

ここで、 S_{x} \gt S_{y} のときの「 +M」の部分は別途求めることにしょう。そうすると、結局次の値を求めればよいことになる。


  •  0 \le x \lt y \le N についての  S_{y} - S_{x} の総和を  P とする
  • 数列  S の転倒数を  Q とする

このとき、答えは

 P + Q \times M

で求められる。


 x を固定して考えよう。 T_{x} = S_{x+1} + S_{x+2} + \dots + S_{N} とすると、次のように求められる。

 T_{x} - S_{x}(N - x)

 x = N, N-1, \dots, 0 について、この値の総和を求めれば良い。

以上まとめて、 O(N \log N) の計算量で答えが求められる。

コード

#include <bits/stdc++.h>
using namespace std;

// BIT
template <class Abel> struct BIT {
    Abel UNITY_SUM = 0;
    vector<Abel> dat;
    
    // [0, n)
    BIT(int n, Abel unity = 0) : UNITY_SUM(unity), dat(n, unity) { }
    void init(int n) {
        dat.assign(n, UNITY_SUM);
    }
    
    // a is 0-indexed
    inline void add(int a, Abel x) {
        for (int i = a; i < (int)dat.size(); i |= i + 1)
            dat[i] = dat[i] + x;
    }
    
    // [0, a), a is 0-indexed
    inline Abel sum(int a) {
        Abel res = UNITY_SUM;
        for (int i = a - 1; i >= 0; i = (i & (i + 1)) - 1)
            res = res + dat[i];
        return res;
    }
    
    // [a, b), a and b are 0-indexed
    inline Abel sum(int a, int b) {
        return sum(b) - sum(a);
    }
    
    // debug
    void print() {
        for (int i = 0; i < (int)dat.size(); ++i)
            cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};

int main() {
    long long N, M;
    cin >> N >> M;
    vector<long long> A(N), S(N+1, 0);
    for (int i = 0; i < N; i++) {
        cin >> A[i];
        S[i+1] = (S[i] + A[i]) % M;
    }

    // 転倒数
    long long inversion_number = 0;
    BIT<long long> bit(M + 1, 0);
    for (int i = N; i >= 0; i--) {
        inversion_number += bit.sum(0, S[i]);
        bit.add(S[i], 1);
    }

    // 転倒数以外の部分を足す
    long long res = inversion_number * M, sum = 0;
    for (int i = N; i >= 0; i--) {
        res += sum - S[i] * (N - i);
        sum += S[i];
    }
    cout << res << endl;
}