けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Yosupo Library Checker - Queue Operate All Composite (2D)

SWAG を履修した!

問題概要

一次関数の列を考える。初期状態では空である。以下の  Q 個のクエリを処理せよ。

  • クエリタイプ 1 ( a, b):一次関数  ax + b を列の末尾に挿入する
  • クエリタイプ 2:列の先頭の要素を削除する
  • クエリタイプ 3 ( x):列を  f_{l}, f_{l+1}, \dots, f_{r-1} としたとき、 \displaystyle f_{r-1}(f_{r-2}(\dots f_{l}(x) \dots)) の値を 998244353 で割った余りを出力する

制約

  •  1 \le Q \le 5 \times 10^{5}

解法

モノイド (セグ木の要件と同じ) の列に対して、以下のクエリに  O(1) の計算量で答えられるデータ構造として SWAG がある (参考資料)。

  • 列の末尾に要素  v を挿入する
  • 列の先頭の要素を削除する
  • 現在の列全体の要素についての演算結果を返す

今回は、「一次関数の合成」を演算としたモノイド「一次関数」について、上記の処理を施せばよい。具体的には、 f(x) = ax + b,  g(x) = cx + d としたとき、

 g(f(x)) = c(ax + b) + d = (ac)x + (bc + d)

であることから、モノイド (型・演算・単位元) を次のように定めることとした。

// 998244353 で割った余りを管理するデータ構造
const int MOD = 998244353;
using mint = Fp<MOD>;
    
// モノイドの定義
using Monoid = pair<mint,mint>;
auto op = [&](Monoid x, Monoid y) {
    return Monoid(x.first * y.first, x.second * y.first + y.second);
};
Monoid identity = {1, 0};

// SWAG のセットアップ
SWAG<Monoid> sw(op, identity);

あとは、ひたすらクエリに答えればよい。計算量は  O(Q) と評価できる。

コード

#include <bits/stdc++.h>
using namespace std;

// SWAG
template<class Monoid> struct SWAG {
    using Func = function<Monoid(Monoid, Monoid)>;
    
    // core member
    Func OP;
    Monoid IDENTITY;
    
    // inner data
    int siz;
    vector<Monoid> dat_left, dat_right, sum_left, sum_right;
    
    // constructor
    SWAG() {}
    SWAG(const Func &op, const Monoid &identity) {
        init(op, identity);
    }
    SWAG(const vector<Monoid> &vec, const Func &op, const Monoid &identity) {
        init(vec, op, identity);
    }
    void init(const Func &op, const Monoid &identity) {
        OP = op;
        IDENTITY = identity;
        clear();
    }
    void init(const vector<Monoid> &vec, const Func &op, const Monoid &identity) {
        init(op, identity);
        for (const auto &v : vec) push_back(v);
    }
    void clear() {
        siz = 0;
        dat_left.clear(), dat_right.clear();
        sum_left = {IDENTITY}, sum_right = {IDENTITY};
    }
    
    // getter
    int size() { return siz; }
    
    // push
    void push_back(const Monoid &v) {
        ++siz;
        dat_right.emplace_back(v);
        sum_right.emplace_back(OP(sum_right.back(), v));
    }
    void push_front(const Monoid &v) {
        ++siz;
        dat_left.emplace_back(v);
        sum_left.emplace_back(OP(v, sum_left.back()));
    }
    
    // pop
    void rebuild() {
        vector<Monoid> tmp;
        for (int i = dat_left.size() - 1; i >= 0; --i) tmp.emplace_back(dat_left[i]);
        for (int i = 0; i < dat_right.size(); ++i) tmp.emplace_back(dat_right[i]);
        clear();
        int mid = tmp.size() / 2;
        for (int i = mid - 1; i >= 0; --i) push_front(tmp[i]);
        for (int i = mid; i < tmp.size(); ++i) push_back(tmp[i]);
        assert(siz == tmp.size());
    }
    void pop_back() {
        if (siz == 1) return clear();
        if (dat_right.empty()) rebuild();
        --siz;
        dat_right.pop_back();
        sum_right.pop_back();
    }
    void pop_front() {
        if (siz == 1) return clear();
        if (dat_left.empty()) rebuild();
        --siz;
        dat_left.pop_back();
        sum_left.pop_back();
    }
    
    // prod
    Monoid prod() {
        return OP(sum_left.back(), sum_right.back());
    }
    
    // debug
    friend ostream& operator << (ostream &s, const SWAG &sw) {
        for (int i = sw.dat_left.size() - 1; i >= 0; --i) s << sw.dat_left[i] << " ";
        for (int i = 0; i < sw.dat_right.size(); ++i) s << sw.dat_right[i] << " ";
        return s;
    }
};

// modint
template<int MOD> struct Fp {
    // inner value
    long long val;
    
    // constructor
    constexpr Fp() : val(0) { }
    constexpr Fp(long long v) : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr long long get() const { return val; }
    constexpr int get_mod() const { return MOD; }
    
    // arithmetic operators
    constexpr Fp operator + () const { return Fp(*this); }
    constexpr Fp operator - () const { return Fp(0) - Fp(*this); }
    constexpr Fp operator + (const Fp &r) const { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp &r) const { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp &r) const { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp &r) const { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp &r) {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp &r) {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp &r) {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp &r) {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp pow(long long n) const {
        Fp res(1), mul(*this);
        while (n > 0) {
            if (n & 1) res *= mul;
            mul *= mul;
            n >>= 1;
        }
        return res;
    }
    constexpr Fp inv() const {
        Fp res(1), div(*this);
        return res / div;
    }

    // other operators
    constexpr bool operator == (const Fp &r) const {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp &r) const {
        return this->val != r.val;
    }
    constexpr Fp& operator ++ () {
        ++val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -- () {
        if (val == 0) val += MOD;
        --val;
        return *this;
    }
    constexpr Fp operator ++ (int) const {
        Fp res = *this;
        ++*this;
        return res;
    }
    constexpr Fp operator -- (int) const {
        Fp res = *this;
        --*this;
        return res;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD> &x) {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD> &x) {
        return os << x.val;
    }
    friend constexpr Fp<MOD> pow(const Fp<MOD> &r, long long n) {
        return r.pow(n);
    }
    friend constexpr Fp<MOD> inv(const Fp<MOD> &r) {
        return r.inv();
    }
};

int main() {
    const int MOD = 998244353;
    using mint = Fp<MOD>;
    
    // setup SWAG
    using Monoid = pair<mint,mint>;
    auto op = [&](Monoid x, Monoid y) {
        return Monoid(x.first * y.first, x.second * y.first + y.second);
    };
    Monoid identity = {1, 0};
    SWAG<Monoid> sw(op, identity);
    
    // queries
    int Q;
    scanf("%d", &Q);
    while (Q--) {
        int t, a, b, x;
        scanf("%d", &t);
        if (t == 0) {
            cin >> a >> b;
            sw.push_back(Monoid(a, b));
        } else if (t == 1) {
            sw.pop_front();
        } else {
            scanf("%d", &x);
            auto f = sw.prod();
            int res = (f.first * x + f.second).val;
            printf("%d\n", res);
        }
    }
}