けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

ACL Beginner Contest E - Replace Digits (青色, 500 点)

まさに遅延評価セグメント木の練習問題!!!

問題へのリンク

問題概要

長さ  N の文字列 S がある。 最初は  S のすべての文字が 1 である。以下の  Q 回のクエリに答えよ。

  • 各クエリは整数  L, R, D が与えられる ( D = 1, 2, \dots, 9)
  •  S L 番目から  R 番目までをすべて  D に書き換える
  •  S を数値とみなしたときの、998244353 で割ったあまりを求めよ

制約

  •  N, Q \le 10^{5}

考えたこと

遅延評価セグメント木とは

の両方を  O(\log N) で実施できることのあるデータ構造だ。特に、「区間に対する更新」ができるのがすごい。区間全体への更新は区間の長さ分の更新が必要に思えるので、一見するとどんなに頑張っても  O(N) 以上かかるように思えてしまう。しかし「遅延評価」という技術によって実現することができる。

さて、遅延評価セグメント木は「作用付きモノイド」を管理するデータ構造だとみなすことができる。

  • 各要素はモノイド (二項演算  \cdot が定義されたもの)
  • 各要素  v に対して「作用」を実施する関数  f(v) がある

という感じだ。ここで重要なこととして、モノイド  u, v に対して

 f(u \cdot v) = f(u) \cdot f(v)

を満たす必要がある。たとえば有名な Starry Sky Tree は

  • 区間に対して一律に値 x を加算
  • 区間の最小値を取得

というのを効率よく行えるデータ構造だが、

 \min(u + x, v + x) = \min(u, v) + x

が成立するが故に、遅延評価セグメント木として実現することができるのだ。

今回

今回は次のようにすると良さそう

モノイド

  • 区間に対応する値を 998244353 で割った値
  • 区間の長さ

のペアでもつ。このとき、モノイド間の二項演算は次のように定義できる。ここで、ten[ i ] は  10^{i} を 998244353 で割った値を表す。

// 二項演算
using pll = pair<mint, int>;
auto fm = [&](pll a, pll b) {
        mint first = a.first * ten[b.second] + b.first;
        int second = a.second + b.second;
        return pll(first, second);
};

作用

作用は区間に対して「値  d で書き換える」ことを意味する。作用は次のように表現できる。なお、作用の単位元は便宜的に 0 といった値に設定しておく。そして d = 0 のときはモノイドに対して何もしないようにする。

また、sum[ i ] は  \sum_{j = 0}^{i - 1} 10^{j} を 998244353 で割った値を表す。

// モノイドに対する作用
auto fa = [&](pll &a, int d) {
        if (d == 0) return; // d が単位元の場合は何もしない
        a.first = sum[a.second] * d;
};
// 作用の合成
auto fl = [&](int &d, int e) {
        d = e;
};

コード

以上の仕様で遅延評価セグメント木を作ることで、 O(N + Q \log N) で解ける。

#include <bits/stdc++.h>
using namespace std;


// Segment Tree
template<class Monoid, class Action> struct SegTree {
    using FuncMonoid = function< Monoid(Monoid, Monoid) >;
    using FuncAction = function< void(Monoid&, Action) >;
    using FuncLazy = function< void(Action&, Action) >;
    FuncMonoid FM;
    FuncAction FA;
    FuncLazy FL;
    Monoid IDENTITY_MONOID;
    Action IDENTITY_LAZY;
    int SIZE, HEIGHT;
    vector<Monoid> dat;
    vector<Action> lazy;
    
    SegTree() { }
    SegTree(int n, const FuncMonoid fm, const FuncAction fa, const FuncLazy fl,
            const Monoid &identity_monoid, const Action &identity_lazy)
    : FM(fm), FA(fa), FL(fl), 
      IDENTITY_MONOID(identity_monoid), IDENTITY_LAZY(identity_lazy) {
        SIZE = 1, HEIGHT = 0;
        while (SIZE < n) SIZE <<= 1, ++HEIGHT;
        dat.assign(SIZE * 2, IDENTITY_MONOID);
        lazy.assign(SIZE * 2, IDENTITY_LAZY);
    }
    void init(int n, const FuncMonoid fm, const FuncAction fa, const FuncLazy fl,
              const Monoid &identity_monoid, const Action &identity_lazy) {
        FM = fm, FA = fa, FL = fl;
        IDENTITY_MONOID = identity_monoid, IDENTITY_LAZY = identity_lazy;
        SIZE = 1; HEIGHT = 0;
        while (SIZE < n) SIZE <<= 1, ++HEIGHT;
        dat.assign(SIZE * 2, IDENTITY_MONOID);
        lazy.assign(SIZE * 2, IDENTITY_LAZY);
    }
    
    // set, a is 0-indexed
    void set(int a, const Monoid &v) { dat[a + SIZE] = v; }
    void build() {
        for (int k = SIZE - 1; k > 0; --k)
            dat[k] = FM(dat[k*2], dat[k*2+1]);
    }
    
    // update [a, b)
    inline void evaluate(int k) {
        if (lazy[k] == IDENTITY_LAZY) return;
        if (k < SIZE) FL(lazy[k*2], lazy[k]), FL(lazy[k*2+1], lazy[k]);
        FA(dat[k], lazy[k]);
        lazy[k] = IDENTITY_LAZY;
    }
    inline void update(int a, int b, const Action &v, int k, int l, int r) {
        evaluate(k);
        if (a <= l && r <= b) FL(lazy[k], v), evaluate(k);
        else if (a < r && l < b) {
            update(a, b, v, k*2, l, (l+r)>>1);
            update(a, b, v, k*2+1, (l+r)>>1, r);
            dat[k] = FM(dat[k*2], dat[k*2+1]);
        }
    }
    inline void update(int a, int b, const Action &v) { 
        update(a, b, v, 1, 0, SIZE);
    }
    
    // get [a, b)
    inline Monoid get(int a, int b, int k, int l, int r) {
        evaluate(k);
        if (a <= l && r <= b)
            return dat[k];
        else if (a < r && l < b)
            return FM(get(a, b, k*2, l, (l+r)>>1), 
                      get(a, b, k*2+1, (l+r)>>1, r));
        else
            return IDENTITY_MONOID;
    }
    inline Monoid get(int a, int b) { 
        return get(a, b, 1, 0, SIZE);
    }
    inline Monoid operator [] (int a) {
        return get(a, a + 1);
    }
    
    // debug
    void print() {
        for (int i = 0; i < SIZE; ++i) {
            if (i) cout << ",";
            cout << (*this)[i];
        }
        cout << endl;
    }
};


// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;
using pll = pair<mint,int>; // val, num

int main() {
    int N, Q;
    cin >> N >> Q;
    vector<mint> ten(N, 1), sum(N+1, 0);
    for (int i = 1; i < N; ++i) ten[i] = ten[i-1] * 10;
    for (int i = 0; i < N; ++i) sum[i+1] = sum[i] + ten[i];

    // define segtree
    auto fm = [&](pll a, pll b) {
        mint first = a.first * ten[b.second] + b.first;
        int second = a.second + b.second;
        return pll(first, second);
    };
    auto fa = [&](pll &a, int d) {
        if (d == 0) return;
        a.first = sum[a.second] * d;
    };
    auto fl = [&](int &d, int e) {
        d = e;
    };
    SegTree<pll, int> seg(N, fm, fa, fl, pll(mint(0), 0), 0);

    // initialization
    for (int i = 0; i < N; ++i) seg.set(i, pll(mint(1), 1));
    seg.build();

    // query
    while (Q--) {
        int l, r, d;
        cin >> l >> r >> d;
        --l;
        seg.update(l, r, d);
        cout << seg.get(0, N).first << endl;
    }
}