けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Codechef Product on the segment by modulo

Disjoint Sparse Table の勉強をした!!!

詳しいことは

あたりを参考に。

問題概要

数列が与えられて、「区間 [L, R) の積を P で割ったあまりを答えよ」というクエリに大量に答えよ。

解法

クエリ回数が多いので、

  • セグツリー: 構築  O(N), クエリ  O(\log{N})

よりも

  • sparse table: 構築  O(N\log{N}), クエリ  O(1)

を使った方がいい。ただ元来の sparse table は


交叉半束

集合  S S 上の二項演算 + があって


を満たすものに適用される。今回はべき等則を満たさないので、そのままでは使えない。アイディアは似つつもそれを


半群

集合  S S 上の二項演算 + があって


を満たすものに適用できるように拡張された disjoint sparse table を使うとピッタリ。

#include <iostream>
#include <vector>
#include <functional>
using namespace std;


template<class SemiGroup> struct DisjointSparseTable {
    using Func = function< SemiGroup(SemiGroup, SemiGroup) >;
    const Func F;
    vector<vector<SemiGroup> > dat;
    vector<int> height;
    
    DisjointSparseTable(const Func &f) : F(f) { }
    DisjointSparseTable(const Func &f, const vector<SemiGroup> &vec) : F(f) { init(vec); }
    void init(const vector<SemiGroup> &vec) {
        int n = (int)vec.size(), h = 1;
        while ((1<<h) <= n) ++h;
        dat.assign(h, vector<SemiGroup>(n));
        height.assign((1<<h), 0);
        for (int i = 2; i < (1<<h); i++) height[i] = height[i>>1]+1;
        for (int i = 0; i < n; ++i) dat[0][i] = vec[i];
        for (int i = 1; i < h; ++i) {
            int s = (1<<i);
            for (int j = 0; j < n; j += (s << 1)) {
                int t = min(j+s, n);
                dat[i][t-1] = vec[t-1];
                for (int k = t-2; k >= j; --k) dat[i][k] = F(vec[k], dat[i][k+1]);
                if (n <= t) break;
                dat[i][t] = vec[t];
                for (int k = t+1; k < min(t+s, n); ++k) dat[i][k] = F(dat[i][k-1], vec[k]);
            }
        }
    }
    SemiGroup get(int a, int b) {
        if (a >= --b) return dat[0][a];
        return F(dat[height[a^b]][a], dat[height[a^b]][b]);
    }
};



int main() {
    int T; scanf("%d", &T);
    int N, P, Q;
    DisjointSparseTable<long long> dst([&](long long a, long long b){return a*b%P;});
    for (int CASE = 0; CASE < T; ++CASE) {
        scanf("%d %d %d", &N, &P, &Q);
        vector<long long> v(N);
        for (int i = 0; i < N; ++i) scanf("%lld", &v[i]);
        dst.init(v);
        vector<int> b(Q/64+2);
        for (int i = 0; i < (int)b.size(); ++i) scanf("%d", &b[i]);
        int x = 0, L = 0, R = 0;
        for (int query = 0; query < Q; ++query) {
            if (query % 64 == 0) L = (b[query/64] + x) % N, R = (b[query/64+1] + x) % N;
            else L = (L + x) % N, R = (R + x) % N;
            if (L > R) swap(L, R);
            x = (dst.get(L, R+1) + 1) % P;
        }
        printf("%d\n", x);
    }
}