けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 369 C - Count Arithmetic Subarrays (3Q, 灰色, 300 点)

とても教育的で典型的なしゃくとり法の問題!

問題概要

 N 個の整数からなる数列  A_{1}, A_{2}, \dots, A_{N} が与えられる。

 A_{l}, A_{l+1}, \dots, A_{r} が等差数列であるような組  (l, r) の個数を求めよ。

制約

  •  1 \le N \le 2 \times 10^{5}
  •  1 \le A_{i} \le 10^{9}

解法 (1):しゃくとり法

今回の問題のように、数列中で条件を満たす区間を考える問題では、しゃくとり法が使えることが多い。

特に今回は、ある区間が条件を満たすとき、それに含まれる区間も条件を満たす。このような場合は、しゃくとり法がドンピシャである。しゃくとり法については、次の記事を参照。

qiita.com

しゃくとり法を用いて、次のコードのように書ける。計算量は  O(N) である。

コード

#include <bits/stdc++.h>
using namespace std;

int main() {
    int N;
    cin >> N;
    vector<long long> A(N);
    for (int i = 0; i < N; i++) cin >> A[i];

    // 区間 [l, r) が条件を満たすかを判定(区間 [l, r-1) は条件を満たすとする)
    auto check = [&](int l, int r) -> bool {
        if (r - l <= 2) return true;
        return A[r-1] - A[r-2] == A[r-2] - A[r-3];
    };

    long long res = 0, right = 0;
    for (long long left = 0; left < N; left++) {
        if (right < left) right = left;
        while (right < N && check(left, right + 1)) right++;
        res += right - left;
    }
    cout << res << endl;
}

 

解法 (2):階差数列を考える

今回の問題のように、数列の隣接要素の差を考えたいような問題では、階差数列を考えることでうまくいくことがある。たとえば

 A = (3, 6, 9, 3)

に対して、階差数列  B を求めると、次のようになる。

 B = (3, 3, -6)

この階差数列  B から、今回の問題は次のように解釈できる。

  • 階差が 3 である部分( B_{1}, B_{2})に対応するのは、 (A_{1}, A_{2}, A_{3}) である
    • これらの隙間と両端を合わせた 4 箇所から 2 箇所選ぶ方法が適する
    • それは、 {}_{4}\mathrm{C}_{2} = 6 通りある
  • 階差が -6 である部分( B_{3})に対応するのは、 (A_{3}, A_{4}) である
    • これらの隙間と両端を合わせた 3 箇所から 2 箇所選ぶ方法が適する
    • それは、 {}_{3}\mathrm{C}_{2} = 3 通りある
  • これらのうち  A_{3} のみからなる数列が被っている
    • よって、答えは  6 + 3 - 1 = 8 通りである

以上の考察を一般化すると、階差数列をランレングス圧縮したときに、次のように解けることがわかる。計算量は  O(N) である。


  • 圧縮される各区間の個数を  n として  {}_{n+2}\mathrm{C}_{2} 通りを足していく
  • 最後に、圧縮される区間の個数を  m として、 m-1 を引く(重複分が  m-1 個ある)

コード

#include <bits/stdc++.h>
using namespace std;

int main() {
    int N;
    cin >> N;
    vector<long long> A(N), B(N-1);
    for (int i = 0; i < N; i++) cin >> A[i];
    for (int i = 0; i + 1 < N; i++) B[i] = A[i+1] - A[i];

    long long res = 0, num = 0;
    for (int i = 0; i < B.size(); ) {
        int j = i;
        while (j < B.size() && B[j] == B[i]) j++;
        
        long long n = j - i;
        res += (n + 2) * (n + 1) / 2;
        num++;  // 区間の個数

        i = j;
    }
    cout << res - (num - 1) << endl;
}