けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 129 F - Takahashi's Basics in Education and Learning (橙色, 600 点)

重たい。。。
でも例えば行列  A に対して

  •  A^{3} + A^{2} + A + E

の計算を行列累乗に帰着させるテクニックは、蟻本中級編の行列累乗のところに載っていたりする。それを膨らませると

  •  x^{5} + 2x^{4} + 3x^{3} + 4x^{2} + 5x + 6 ({\rm mod} M)

みたいな計算も、行列累乗でできそうという気持ちには確かになるよね。。。(なったけど昨日間に合わなかった)

問題へのリンク

問題概要

初項が  A、交差が  B、長さが  L の等差数列を十進法表記で表して連結してできる数を考える。例えば  A = 3、B = 4、L = 5 なら、 3, 7, 11, 15, 19 なので  317111519 になる。

この数を  M で割ったあまりを求めよ。ただし  M は素数とは限らない。

制約

  •  1 \le L, A, B \lt 10^{18}
  •  2 \le M \le 10^{9}
  • 等差数列の要素はすべて  10^{18} 未満

考えたこと

 A = 3, B = 4, L = 5 の例でいうと、

  •  3 \times (10^{7} + 10^{6} + 10^{4} + 10^{2} + 10^{0})
  •  4 \times (10^{6} + 10^{4} + 10^{2} + 10^{0})
  •  4 \times (10^{4} + 10^{2} + 10^{0})
  •  4 \times (10^{2} + 10^{0})
  •  4 \times (10^{0})

の総和をとる問題だと思うことができて、そういう方針で考えることにした。そして、基本的には等差数列において桁数ごとに考えてあげて  d = 18, 17, ..., 1 に対して  d 桁の整数 ( 10^{d-1} 以上  10^{d}-1 以下) が  C_d 個あるときに、 x = 10^{d} として、

  •  d 桁より大きい整数の分について + E_d 桁分あるとして、
    • その分についての項和を  F_{d} として
  • 以下を合算して  B をかけたもの
    •  (x^{C_{d}} + 2x^{C_{d}-1} + 3x^{C_{d}-2} + \dots + C_{d}) \times x^{E_{d}}
    •  C_{d} \times F_{d}

を各  d に対して合計していけばよい。最後に初項  A の分については補正する。

多項式計算

今回

  •  x^{4} + x^{3} + x^{2} + x + 1
  •  x^{4} + 2x^{3} + 3x^{2} + 4x + 5

といった計算を高速に実行する必要性にかられている。これらは逆元を使えるなら明示的な式で書けるのだが、今回は任意 MOD なのでそうはいかない。でもその場合にも行列累乗に帰着させる方法は蟻本中級編にも載っている!!!!!!

前者は

x 1
0 1

後者は

x 1 1
0 1 1
0 0 1

という行列を累乗していけばよい。

#include <iostream>
#include <vector>
using namespace std;

// px + r (x >= 0) で表せる整数のうち、x 以上となる最小の整数
long long lower_amari(long long p, long long r, long long x) {
    if (r >= x) return r;
    return (x - r + p-1) / p * p + r;
}

// matrix
int MOD;
struct Matrix {
    vector<vector<long long> > val;
    Matrix(int n, int m, long long x = 0) : val(n, vector<long long>(m, x)) {}
    void init(int n, int m, long long x = 0) {val.assign(n, vector<long long>(m, x));}
    size_t size() const {return val.size();}
    inline vector<long long>& operator [] (int i) {return val[i];}
};

Matrix operator * (Matrix A, Matrix B) {
    Matrix R(A.size(), B[0].size());
    for (int i = 0; i < A.size(); ++i) 
        for (int j = 0; j < B[0].size(); ++j)
            for (int k = 0; k < B.size(); ++k) 
                R[i][j] = (R[i][j] + A[i][k] * B[k][j] % MOD) % MOD; 
    return R;
}

Matrix pow(Matrix A, long long n) {
    Matrix R(A.size(), A.size());
    for (int i = 0; i < A.size(); ++i) R[i][i] = 1;
    while (n > 0) {
        if (n & 1) R = R * A;
        A = A * A;
        n >>= 1;
    }
    return R;
}

// mod
long long modpow(long long a, long long n) {
    long long res = 1;
    while (n > 0) {
        if (n & 1) res = res * a % MOD;
        a = a * a % MOD;
        n >>= 1;
    }
    return res;
}

// x^(n-1) + x^(n-2) + ... + 1
long long ser(long long x, long long n) {
    if (n == 0) return 0;
    Matrix M(2, 2);
    M[0][0] = x % MOD; M[0][1] = 1;
    M[1][0] = 0; M[1][1] = 1;
    auto P = pow(M, n-1);
    return (P[0][0] + P[0][1]) % MOD;
}

// x^(n-1) + 2x^(n-2) + 3x^(n-3) + ... + n
long long ser2(long long x, long long n) {
    if (n == 0) return 0;
    Matrix M(3, 3);
    M[0][0] = x % MOD; M[0][1] = 1; M[0][2] = 1;
    M[1][0] = 0; M[1][1] = 1; M[1][2] = 1;
    M[2][0] = 0; M[2][1] = 0; M[2][2] = 1;
    auto P = pow(M, n-1);
    return (P[0][0] + P[0][1] + P[0][2]) % MOD;
}

long long solve(long long L, long long A, long long B, long long M) {
    MOD = M;
    long long C = A + B * (L-1);
    long long res = 0;
    long long curd = 0;
    long long sum = 0;
    long long fac = 1;
    for (long long d = 18; d >= 1; --d) {
        // まずは d 桁の個数を数える
        long long beki = 1;
        for (int i = 0; i < d - 1; ++i) beki *= 10;
        long long low = beki, high = beki * 10 - 1;

        high = min(high, C);
        if (high < A || C < low) continue;

        long long low_kou = lower_amari(B, A, low);
        long long up_kou = lower_amari(B, A, high + 1);
        long long num = (up_kou - low_kou) / B;

        // それを使って計算
        long long x = modpow(10LL, d);
        long long alr = num % MOD * sum % MOD;
        long long add = ser2(x, num) * fac % MOD;
        res = (res + (add + alr) % MOD * (B % MOD) % MOD) % MOD;
        sum = (sum + fac * ser(x, num) % MOD) % MOD;
        fac = fac * modpow(modpow(10LL, d), num) % MOD;
    }
    long long hosei = ((A - B) % M + M) % M * sum % M;
    res = (res + hosei) % M;
    return res;   
}

int main() {
    long long L, A, B, M;
    while (cin >> L >> A >> B >> M) cout << solve(L, A, B, M) << endl;
}