この手の
- 周期性を利用する
- ダブリングする
のどちらでも解けるタイプの問題、最近めっちゃ多いね。
問題概要
を で割ったあまりを で表す。
整数 が与えられる。以下で定まる漸化式の最初の 項の総和を求めよ。
制約
考えたこと
最初誤読してしまった。単純に
を計算すればいいのかと思ってしまった。しかしそうではなくて、
- ...
の総和を求めるというのが正しい。愚直に 項すべて求めたのでは間に合わないので高速化する必要がある。例によって大きく 2 つのやり方がある。
解法 (1):周期性に注目
, , , ... と計算していくと、いつかはかならず下図みたいに循環する。
ここで注意したいことは、かならずしも「スタート地点」に戻ってくるとは限らない。しかしそれでも、「かつて来たことのある地点」に再び戻ってくる瞬間があるのだ。なぜなら、 で割ったあまりというのはそもそも の 種類の値しかとりえないからだ。具体的には 項目まで来たとき、かならずその中に「等しいペア」は存在する。
そして、「かつて来たことのある地点」に初めて到達する部分が重要。そこから先は同じサイクルを延々と未来永劫ずっと繰り返すことになる。よって次のように求めることができる。
- 最初は愚直に , , , ... を で割ったあまりを計算していく
- その過程で 項目まで到達した場合はその時点での答えをリターンする
- 項目ではじめて「かつて来たことのある地点」に到達した場合は次のようになることがわかる
- N -= a とする
- サイクルの項数を c として、q = N / c、r = N % c とすると
- サイクルは q 週して、追加で r ステップ進む
計算量は 。
#include <bits/stdc++.h> using namespace std; long long solve() { long long N, X, M; cin >> N >> X >> M; vector<int> ord(M, -1); // かつて来た地点を求める vector<long long> rireki, syu; long long res = 0; for (int n = 0; n < N; ++n) { // かつて来た地点に戻ったら if (ord[X] != -1) { int p = ord[X]; for (long long i = p; i < n; ++i) syu.push_back(rireki[i]); break; } ord[X] = n; rireki.push_back(X); res += X; X = (X * X) % M; } N -= rireki.size(); // 戻る前に N 項目に到達した場合 if (N == 0) return res; // 周期の累積和をとる vector<long long> sum(syu.size() + 1, 0); for (int i = 0; i < syu.size(); ++i) sum[i+1] = sum[i] + syu[i]; // 周期を q 週して r あまる long long q = N / syu.size(); long long r = N % syu.size(); res += sum[syu.size()] * q + sum[r]; return res; } int main() { cout << solve() << endl; }
解法 (2):ダブリング
この手の N ステップ分の挙動を解析する問題は、ダブリングでも解けることで有名。
- nex[ ][ ] := から ステップ進んだ先の値
- sum[ ][ ] := から 項分の総和
という値をダブリングによって求めれば OK。計算量は 。
#include <bits/stdc++.h> using namespace std; long long solve() { long long N, X, M; cin >> N >> X >> M; const int MAXD = 55; vector<vector<long long>> nex(MAXD+1, vector<long long>(M, -1)), sum(MAXD+1, vector<long long>(M, 0)); for (long long r = 0; r < M; ++r) { nex[0][r] = r * r % M; sum[0][r] = r; } for (int p = 0; p < MAXD; ++p) { for (int r = 0; r < M; ++r) { nex[p+1][r] = nex[p][nex[p][r]]; sum[p+1][r] = sum[p][r] + sum[p][nex[p][r]]; } } long long res = 0; int cur = X; for (int p = MAXD; p >= 0; --p) { if (N & (1LL<<p)) { res += sum[p][cur]; cur = nex[p][cur]; } } return res; } int main() { cout << solve() << endl; }