けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 278 D - All Assign Point Add (茶色, 400 点)

データの持ち方をうまく工夫することで、計算量を改善する系の問題!

問題概要

長さ  N の数列  A_{1}, A_{2}, \dots, A_{N} が与えられる。 Q 個のクエリが与えられるので、それらを順に処理せよ。クエリは次の 3 種類ある。

  1. x:数列  A をすべて  x に書き換える
  2. i x A_{i} x を足す
  3. i A_{i} の値を出力する

制約

  •  1 \le N, Q \le 2 \times 10^{5}

この問題の何が難しいか

まず、もし仮に何も工夫せずに愚直にクエリを処理したら、計算量がどうなるかを見積もってみましょう。

クエリタイプ 2, 3 については、愚直にやっても  O(1) の計算量で済みます。

しかし、クエリタイプ 1 がやばくて、 N 個の値をすべて書き換えるため、それだけで  O(N) の計算量を要します。

もしクエリタイプ 1 が沢山含まれていると、最悪計算量は  O(QN) となります。これは TLE となります。

クエリタイプ 1 では、「全部  x にするよ」という情報のみを持つ

そこで、クエリタイプ 1 では、愚直にすべての  A_{i} x に書き換えるのではなく、全て書き換えたことを示す 1 つの変数 base を用意して

base = x;

というように処理することにしましょう。

これに伴い、これまで問題なかったクエリタイプ 2, 3 にも工夫が必要となります。たとえばクエリタイプ 2 を処理しようとして  A_{i} x を足そうとするとき、 A_{i} の真の値はそこにはないかもしれないのです。

 A_{i} の値は、base になっているかもしれないし、他の値になっているかもしれません。

そこで、次のデータを持つことにします (これは C++ の場合です。Python3 の場合は辞書型を使うとよいでしょう)。


  • map<int,long long> add

add[i] は、 A_{i} の値が base からどれだけ離れているかを示す


たとえば、次の順にクエリが来たとしましょう。

  • クエリタイプ 1:7
  • クエリタイプ 2:1 5
  • クエリタイプ 2:2 3
  • クエリタイプ 2:1 4

これらをこなすと、baseadd の値は次のようになります。

  • `base = 7'
  • add[1] = 9, add[2] = 3

すでに add が溜まった状態でクエリタイプ 1 が来るとき

最後の問題は、すでに add が溜まった状態で、クエリタイプ 1 が襲ってくることです。

このとき、数列  A の各要素が今までどんな歴史を辿って来たかに関係なく、一律に値  x にセットされます。add の情報はすべて破棄することになるのです。具体的には次のように実装すればよいでしょう。

base = x;
add.clear();

計算量が気になるかもしれません。しかし、add の要素が増える瞬間の個数はすべて総和をとっても  Q 回以下なのです。

ですので、add.clear() を何回実施したとしても、それによって消す要素の個数の総和は述べ  Q 個以下となります。よって問題ありません。

全体の計算量は、 O((N+Q)\log N) となります。 O(\log N) がつくのは、map 型の操作に  O(\log N) の計算量を要するからです。

詳細は、次のコードを見れば色々わかるはずです。

コード

C++

#include <bits/stdc++.h>
using namespace std;

int main() {
    // クエリ処理のためのデータ構造
    long long base = 0;
    map<int, long long> add;
    
    // 入力を受け取る
    // 初期状態の A[i] の値も add に含めてしまうことにする
    int N, Q;
    cin >> N;
    for (int i = 0; i < N; ++i) {
        long long A; cin >> A;
        add[i] += A;
    }
    
    // クエリを処理していく
    cin >> Q;
    for (int query = 0; query < Q; ++query) {
        int type;
        cin >> type;
        if (type == 1) {
            int x; cin >> x;
            
            // base を x にする
            base = x, add.clear();
        } else if (type == 2) {
            int i, x; cin >> i >> x;
            --i;
            
            // add を更新する
            add[i] += x;
        } else {
            int i; cin >> i;
            --i;
            
            // base からの差分を足して出力する
            cout << base + add[i] << endl;
        }
    }
}

Python3

# クエリ処理のためのデータ構造
base = 0;
add = defaultdict(int)
    
# 入力を受け取る
N = int(input())
A = list(map(int, input().split()))
for i, a in enumerate(A):
    add[i] += a

# クエリを処理していく
Q = int(input())
for query in range(Q):
    type, *q = map(int, input().split())
    if type == 1:
        x = q[0]
        base = x
        add = defaultdict(int)
    elif type == 2:
        i, x = q[0]-1, q[1]
        add[i] += x
    else:
        i = q[0]-1
        print(base + add[i])