けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

フォルシアゆるふわ競プロオンサイト #3 C - Bananas Multiplier (Hard) locked

ツリー上のパスクエリについての教育的典型題だった

問題へのリンク

問題概要 (意訳)

 N 頂点の重み付きツリーが与えられる。以下の  Q 個のクエリに答えよ。

  • 各クエリは 2 頂点  m,  p と値  x が指定され、ツリー上の  m p とを結ぶパス上の辺の重みの積を求め、それと  x を掛け合わせた数を 1000000007 で割ったあまりを出力せよ

制約

  •  N, Q \le 10^{5}

考えたこと

この問題はつまり、ツリーの 2 頂点 s, t を結ぶパスの重みの積を高速に求めてください、という問題ということになる。

僕は LCA (最小共通祖先) を求める方針で解くことにした。まず前処理として、ツリーの根を適当に定めて根付き木として、

  • mul[ v ] := 根 r から頂点 v へ至るパス上の辺の重みの積を 1000000007 で割ったあまり

を求めておくことにした。そうしておけば、各クエリに対して、

  • s と t の LCA を求める (g とする)
  • s-g パスについての辺の重みの積 sg は、sg = mul[ s ] / mul[ g ]
  • t-g パスについての辺の重みの積 tg は、tg = mul[ t ] / mul [ g ]
  • よって、s-t パスについての辺の重みの積は sg × tg

という風に求めることができる。上記のことは、たとえば実際には t が s の先祖であって、s と t の LCA が t に一致するような場合であっても成立する (その場合 tg = 1 となる)。

LCA の求め方は、ダブリングを用いる方法などが蟻本に載っている。計算量は  O(N + Q\log{N})

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;
using Edge = pair<int, mint>;
using Graph = vector<vector<Edge>>;

// input
int N;
Graph G;
vector<mint> mul;

// LCA and so on
const int MAX_DEPTH = 20;
vector<int> depth;
vector<vector<int>> parent;

int getLCA(int s, int t) {
    if (depth[s] > depth[t]) swap(s, t);
    for (int i = MAX_DEPTH-1; i >= 0; --i) {
        auto nt = parent[i][t];
        if (nt == -1 || depth[nt] < depth[s]) continue;
        t = nt;
    }
    if (s == t) return s;
    for (int i = MAX_DEPTH-1; i >= 0; --i) {
        auto ns = parent[i][s], nt = parent[i][t];
        if (ns == -1 || nt == -1 || ns == nt) continue;
        s = ns, t = nt;
    }
    return parent[0][s];
}

void dfs(int v, int p, int d, mint m) {
    depth[v] = d;
    mul[v] = m;
    parent[0][v] = p;
    for (auto e : G[v]) {
        if (e.first == p) continue;
        dfs(e.first, v, d + 1, m * e.second);
    }
}

void preprocess() {
    depth.assign(N, 0);
    mul.assign(N, 1);
    parent.assign(MAX_DEPTH, vector<int>(N, -1));
    dfs(0, -1, 0, 1);
    for (int i = 0; i + 1 < MAX_DEPTH; ++i) {
        for (int v = 0; v < N; ++v) {
            parent[i+1][v] = parent[i][parent[i][v]];
        }
    }
}

int main() {
    cin >> N;
    G.assign(N, vector<Edge>());
    for (int i = 0; i < N-1; ++i) {
        int u, v, c;
        cin >> u >> v >> c;
        --u, --v;
        G[u].emplace_back(v, c);
        G[v].emplace_back(u, c);
    }
    preprocess();

    int Q; cin >> Q;
    while (Q--) {
        int s, t, x;
        cin >> s >> t >> x;
        --s, --t;
        int g = getLCA(s, t);
        cout << mul[s] * mul[t] / (mul[g] * mul[g]) * x << endl;
    }
}