構文解析、超絶苦手系だけど苦手とばかり言っていられない。
問題概要
(a)*a+((a+(a*(a))-(a)*a+a*a))*a
のような文字列 が与えられる。各 a に入るデフォルトの数値 が与えられている。今 個のクエリが来て、各クエリは
- : 個目の a を で置き換えたときの数式の値を 1000000007 で割ったあまりを求めよ
というものである。
制約
- の長さは 200000 以下
解法
まず重要な観察として、
- 各 a の値に対する、全体結果の値の応答は線形である
したがって、各 a を 1 増やしたときに全体の応答がどう変化するかを求められればよい。
まず与えられた文字列を構文解析する。そうすると、こんな感じの二分木を作ることができる:
ここから木 DP っぽいことをやっていく。
dp[v] := v を根とする部分木について、v の値が 1 増えたとしたら全体結果は幾つ増えるか
という値を持っておく。通常の木 DP では葉から根へと伝播していくが、今回は向きが反対で、根から葉へと伝播していく感じになる。v の演算子によって左右の子ノードの dp 値がどうなるかを考えていく。
例えば dp[v] = w で v の演算子が「-」のとき、v の右ノードが 1 増えれば、v のノードが -1 増えることになるので、全体の値が -w 増えることになる。したがって、dp[right[v]] = -w である。
- dp[v] = w であったとき
v の演算子 | 左ノード | 右ノード |
---|---|---|
+ | w | w |
- | w | -w |
* | w*(右ノードのデフォルト値) | w*(左ノードのデフォルト値) |
こういう感じで伝播していけばよい。各クエリに対する答えは
(ルートのデフォルト値) + (x - 指定ノードのデフォルト値) * dp[指定ノード]
という感じになる。
#include <iostream> #include <vector> #include <string> using namespace std; const int MOD = 1000000007; inline long long mod(long long n) { n %= MOD; if (n < 0) n += MOD; return n; } vector<long long> given; template<class T> struct Parser { // results int root; // vals[root] is the answer vector<T> vals; // value of each node vector<char> ops; // operator of each node ('a' means leaf values) vector<int> left, right; // the index of left-node, right-node // member to solve each problem vector<int> ids; // 何番目の a か int id = 0; // generate nodes int newnode(char op, int lp, int rp, T val = 0) { ops.push_back(op); left.push_back(lp); right.push_back(rp); if (op == 'a') { vals.push_back(val); ids.push_back(id++); } else { if (op == '+') vals.push_back(mod(vals[lp] + vals[rp])); else if (op == '-') vals.push_back(mod(vals[lp] - vals[rp])); else if (op == '*') vals.push_back(mod(vals[lp] * vals[rp])); ids.push_back(-1); } return (int)vals.size() - 1; } // main solver T solve(const string &S) { int p = 0; root = expr(S, p); return vals[root]; } // parser int expr(const string &S, int &p) { int lp = factor(S, p); while (p < (int)S.size() && (S[p] == '+' || S[p] == '-')) { char op = S[p]; ++p; int rp = factor(S, p); lp = newnode(op, lp, rp); } return lp; } int factor(const string &S, int &p) { int lp = value(S, p); while (p < (int)S.size() && (S[p]== '*' || S[p] == '/')) { char op = S[p]; ++p; int rp = value(S, p); lp = newnode(op, lp, rp); } return lp; } int value(const string &S, int &p) { if (S[p] == '(') { ++p; // skip '(' int lp = expr(S, p); ++p; // skip ')' return lp; } else { /* each process */ while (S[p] == 'a') ++p; T val = given[id]; return newnode('a', -1, -1, val); } } }; Parser<long long> ps; // それぞれの a が 1 変わったらどうなるかを調べる vector<long long> dp; void rec(int v, long long w) { if (ps.ops[v] == 'a') { dp[ps.ids[v]] = w; } else if (ps.ops[v] == '+') { rec(ps.left[v], w); rec(ps.right[v], w); } else if (ps.ops[v] == '-') { rec(ps.left[v], w); rec(ps.right[v], mod(-w)); } else if (ps.ops[v] == '*') { rec(ps.left[v], mod(w * ps.vals[ps.right[v]])); rec(ps.right[v], mod(w * ps.vals[ps.left[v]])); } } int main() { string S; cin >> S; int Q; cin >> Q; int N = 0; for (int i = 0; i < S.size(); ++i) if (S[i] == 'a') ++N; given.resize(N); for (int i = 0; i < N; ++i) cin >> given[i]; long long base = ps.solve(S); dp.resize(N); rec(ps.root, 1); for (int q = 0; q < Q; ++q) { long long id, x; cin >> id >> x; --id; long long res = mod(base + (x - given[id]) * dp[id]); cout << res << endl; } }