けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

2018 codeFlyer 本選 E - 数式とクエリ (700 点)

構文解析、超絶苦手系だけど苦手とばかり言っていられない。

問題へのリンク

問題概要

(a)*a+((a+(a*(a))-(a)*a+a*a))*a

のような文字列  S が与えられる。各 a に入るデフォルトの数値  a_1, a_2, \dots, a_N が与えられている。今  Q 個のクエリが来て、各クエリは

  •  b, x b 個目の a を  x で置き換えたときの数式の値を 1000000007 で割ったあまりを求めよ

というものである。

制約

  •  S の長さは 200000 以下
  •  1 \le Q \le 10^{5}

解法

まず重要な観察として、

  • 各 a の値に対する、全体結果の値の応答は線形である

したがって、各 a を 1 増やしたときに全体の応答がどう変化するかを求められればよい。

まず与えられた文字列を構文解析する。そうすると、こんな感じの二分木を作ることができる:

ここから木 DP っぽいことをやっていく。

dp[v] := v を根とする部分木について、v の値が 1 増えたとしたら全体結果は幾つ増えるか

という値を持っておく。通常の木 DP では葉から根へと伝播していくが、今回は向きが反対で、根から葉へと伝播していく感じになる。v の演算子によって左右の子ノードの dp 値がどうなるかを考えていく。

例えば dp[v] = w で v の演算子が「-」のとき、v の右ノードが 1 増えれば、v のノードが -1 増えることになるので、全体の値が -w 増えることになる。したがって、dp[right[v]] = -w である。

  • dp[v] = w であったとき
v の演算子 左ノード 右ノード
+ w w
- w -w
* w*(右ノードのデフォルト値) w*(左ノードのデフォルト値)

こういう感じで伝播していけばよい。各クエリに対する答えは


(ルートのデフォルト値) + (x - 指定ノードのデフォルト値) * dp[指定ノード]


という感じになる。

#include <iostream>
#include <vector>
#include <string>
using namespace std;

const int MOD = 1000000007;
inline long long mod(long long n) {
    n %= MOD;
    if (n < 0) n += MOD;
    return n;
}

vector<long long> given;
template<class T> struct Parser {
    // results
    int root;                       // vals[root] is the answer
    vector<T> vals;                 // value of each node
    vector<char> ops;               // operator of each node ('a' means leaf values)
    vector<int> left, right;        // the index of left-node, right-node
    
    // member to solve each problem
    vector<int> ids; // 何番目の a か
    int id = 0;
    
    // generate nodes
    int newnode(char op, int lp, int rp, T val = 0) {
        ops.push_back(op); left.push_back(lp); right.push_back(rp);
        if (op == 'a') {
            vals.push_back(val);
            ids.push_back(id++);
        }
        else {
            if (op == '+') vals.push_back(mod(vals[lp] + vals[rp]));
            else if (op == '-') vals.push_back(mod(vals[lp] - vals[rp]));
            else if (op == '*') vals.push_back(mod(vals[lp] * vals[rp]));
            ids.push_back(-1);
        }
        return (int)vals.size() - 1;
    }
    
    // main solver
    T solve(const string &S) {
        int p = 0;
        root = expr(S, p);
        return vals[root];
    }
    
    // parser
    int expr(const string &S, int &p) {
        int lp = factor(S, p);
        while (p < (int)S.size() && (S[p] == '+' || S[p] == '-')) {
            char op = S[p]; ++p;
            int rp = factor(S, p);
            lp = newnode(op, lp, rp);
        }
        return lp;
    }
    
    int factor(const string &S, int &p) {
        int lp = value(S, p);
        while (p < (int)S.size() && (S[p]== '*' || S[p] == '/')) {
            char op = S[p]; ++p;
            int rp = value(S, p);
            lp = newnode(op, lp, rp);
        }
        return lp;
    }
    
    int value(const string &S, int &p) {
        if (S[p] == '(') {
            ++p;                    // skip '('
            int lp = expr(S, p);
            ++p;                    // skip ')'
            return lp;
        }
        else {
            /* each process */
            while (S[p] == 'a') ++p;
            T val = given[id];
            return newnode('a', -1, -1, val);
        }
    }
};

Parser<long long> ps;

// それぞれの a が 1 変わったらどうなるかを調べる
vector<long long> dp;
void rec(int v, long long w) {
    if (ps.ops[v] == 'a') {
        dp[ps.ids[v]] = w;
    }
    else if (ps.ops[v] == '+') {
        rec(ps.left[v], w);
        rec(ps.right[v], w);
    }
    else if (ps.ops[v] == '-') {
        rec(ps.left[v], w);
        rec(ps.right[v], mod(-w));
    }
    else if (ps.ops[v] == '*') {
        rec(ps.left[v], mod(w * ps.vals[ps.right[v]]));
        rec(ps.right[v], mod(w * ps.vals[ps.left[v]]));
    }
}

int main() {
    string S; cin >> S;
    int Q; cin >> Q;
    int N = 0;
    for (int i = 0; i < S.size(); ++i) if (S[i] == 'a') ++N;
    given.resize(N); for (int i = 0; i < N; ++i) cin >> given[i];
    
    long long base = ps.solve(S);
    dp.resize(N);
    rec(ps.root, 1);
    for (int q = 0; q < Q; ++q) {
        long long id, x; cin >> id >> x; --id;
        long long res = mod(base + (x - given[id]) * dp[id]);
        cout << res << endl;
    }
}