けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

競プロ典型 90 問 005 - Restricted Digits(4D, ★7)

きたまさ法によく似たタイプの DP ダブリング高速化!

ただかなり難しい問題だと思うので、004 まで順調に解いていた方が、この問題で挫折しないように注意!!!

無理に解こうとせずに飛ばすのも一案だと思います......

問題概要

 K 種類の数字  c_{1}​, c_{2}​, \dots, c_{K}​ のみを使うことで作れる  N 桁の正の整数のうち、 B の倍数は何個あるか、 1000000007 で割ったあまりを答えてください。

制約

  •  1 \le K \le 9
  •  1 \le N \le 10^{18}
  •  2 \le B \le 1000

まずは DP

今回の問題のように、「〜という数字を使って整数を作り、それが  B の倍数になるようにする」というタイプの問題では、DP が有効な印象があります。

drken1215.hatenablog.com

そのような DP では、整数というものを、

  •  10 倍する
  •  a を足す

という  2 つの手続きを繰り返すことで作り上げるものと考えます。たとえば " 4649" という整数は

  • 整数  0 からスタートして
  •  10 倍して  4 を足すと、 4 になる
  •  10 倍して  6 を足すと、 46 になる
  •  10 倍して  4 を足すと、 464 になる
  •  10 倍して  9 を足すと、 4649 になる

というように作れます。これを踏まえて今回の問題は、次の配列 dp を考えることで (計算量を無視すれば) 解けます。


dp[i][r] ← 所定の  K 種類の数字のみを用いた  i 桁の整数であって、 B で割ったあまりが  r であるようなものの個数


このように、「 B の倍数の個数」を聞かれているときに、「 B で割ったあまり」を定義に含めた DP を考えると有効なことは多々あります。

この DP の遷移式は次のように考えられます。

  • dp[i + 1][(r * 10 + c[k]) % B] += dp[i][r]

この時点で計算量は  O(NBK) になります。このままでは当然 TLE。

行列累乗することで  O(B^{3} \log N) にはなりますが、 B \le 1000 なのでそれでも間に合いません。このようなとき、ダブリング解法を考えることが有効な場合があるようです。

ダブリング解法を適用すると、 O(B^{2} \log N) になります。

ダブリングの考え方

今回は配列 dp[N] (dp[N][0], dp[N][1], ..., dp[N][B-1] という  B 個の要素からなる配列) を求めたいということになります。

上の「dp[i + 1][(r * 10 + c[k]) % B] += dp[i][r]」という漸化式は、配列 dp[0] からスタートして、dp[0]dp[1]dp[2] → ... → dp[N] という順に計算していくものと解釈できます。しかしこのままでは  O(N) 回のステップを要することになります。

しかし今回は、ダブリングすることで、 O(\log N) ステップの更新で dp[N] が求められるのです!

そこで用いられるアイデアは、繰り返し二乗法と同じものです。繰り返し二乗法とは、たとえば  3^{64} を計算したいときに「 3 64 回かける」とやるのではなく、

  •  3^{1} からスタートして
  •  3^{2} (=  3^{1} \times 3^{1})
  •  3^{4} (=  3^{2} \times 3^{2})
  •  3^{8} (=  3^{4} \times 3^{4})
  •  3^{16} (=  3^{8} \times 3^{8})
  •  3^{32} (=  3^{16} \times 3^{16})
  •  3^{64} (=  3^{32} \times 3^{32})

という順に計算することで、わずか  6 回のかけざんで求められます!

この方法だと一見  3^{N} を計算するのに  N 2 の冪乗である必要があるように思われるかもしれません。しかしたとえば  3^{100} を計算するときにも、

 100 = 64 + 32 + 4

であることに着目して、

 3^{100} = 3^{64} \times 3^{32} \times 3^{4}

というように計算できます。念のために、一般に通用する方法を整理しましょう。 100 は二進法で表すと  1100100 となります。このことから

 100 = 2^{2} + 2^{5} + 2^{6}

であることが導かれます。一般に  3^{N} を計算する場合には、 N を二進法で表すことで求められます。このような方法をダブリングと呼ぶことがあります。

今回の DP の場合

以上の  3^{100} を計算したような方法を使えるためには、


配列 dp[i] と配列 dp[j] とから、配列 dp[i + j] を計算する


ことが高速にできればよいことになります。これさえできれば、 3^{N} を計算するときのようなダブリング手法によって、配列 dp[N] O(\log N) 回のステップで計算できます。

ここで、dp[i + j] とは  i+j 桁の数を考えていることになります。これを「前半の  i 桁分」と「後半の  j 桁分」とに分けて考えてみましょう。

  • 前半  i 桁分を  B で割ったあまりが  p (dp[i][p] 通りあります)
  • 後半  j 桁分を  B で割ったあまりが  q (dp[j][q] 通りあります)

であるとすると、そのような  i+j 桁の整数を  B で割ったあまりは

 (p \times 10^{j} + q) %  B

となります。つまり、


dp[i + j][(p * tj + q) % B] += dp[i][p] * dp[j][q]


という遷移で表せます。ここで  10^{j} B で割ったあまりを tj と書いています。これはまさに、「配列 dp[i] と配列 dp[j] とから配列 dp[i + j] を計算する遷移式」です。そしてこの遷移は  O(B^{2}) の計算量でできます。

よって全体としては、ダブリングによって  O(\log N) 回の遷移を行うことになりますので、 O(B^{2} \log N) の計算量で解けます。

コード (C++ と PyPy3)

C++

#include <iostream>
#include <vector>
using namespace std;

// MOD
constexpr int MOD = 1000000007;

// N の対数
constexpr int LOG = 62;

int main() {
    // 入力
    long long N, B, K;
    cin >> N >> B >> K;
    vector<int> C(K);
    for (auto& ci : C) cin >> ci;

    // dp[i] と dp[j] を掛け合わせて dp[i+j] を得る処理
    // tj: 10^j を B で割ったあまり
    auto mul = [&](const vector<long long>& dpi,
                   const vector<long long>& dpj,
                   long long tj) -> vector<long long> {
        vector<long long> res(B, 0);
        for (long long p = 0; p < B; ++p) {
            for (long long q = 0; q < B; ++q) {
                res[(p * tj + q) % B] += dpi[p] * dpj[q];
                res[(p * tj + q) % B] %= MOD;
            }
        }
        return res;
    };

    // ten[i]: 10^(2^i) を B で割ったあまり
    vector<long long> ten(LOG, 10);
    for (int i = 1; i < LOG; ++i) {
        ten[i] = (ten[i - 1] * ten[i - 1]) % B;
    }

    // dp[2^i][r] を doubling[i][r] で書くことにする
    vector<vector<long long>> doubling(LOG, vector<long long>(B, 0));

    // 初期化 (doubleing[0] := dp[1])
    for (int k = 0; k < K; ++k) {
        doubling[0][C[k] % B] += 1;
    }

    // ダブリング
    for (int i = 1; i < LOG; ++i) {
        doubling[i] = mul(doubling[i - 1], doubling[i - 1], ten[i - 1]);
    }

    // ダブリングした結果をもとに答えを求める
    vector<long long> res(B, 0);
    res[0] = 1;
    for (int i = 0; i < LOG; ++i) {
        // N を二の冪乗の積で表すときに、2^i を含むかどうか
        if (N & (1LL << i)) {
            res = mul(res, doubling[i], ten[i]);
        }
    }
    cout << res[0] << endl;
}   

Python3 (PyPy3)

# MOD
MOD = 1000000007

# N の対数
LOG = 62

# 入力
N, B, K = map(int, input().split())
C = list(map(int, input().split()))

# dp[i] と dp[j] を掛け合わせて dp[i+j] を得る関数
# tj: 10^j を B で割ったあまり
def mul(dpi, dpj, tj):
    res = [0] * B
    for p in range(B):
        for q in range(B):
            res[(p * tj + q) % B] += dpi[p] * dpj[q]
            res[(p * tj + q) % B] %= MOD
    return res

# ten[i]: 10^(2^i) を B で割ったあまり
ten = [10] * LOG
for i in range(1, LOG):
    ten[i] = (ten[i - 1] * ten[i - 1]) % B

# dp[2^i][r] を doubling[i][r] と書くことにする
doubling = [[0] * B for _ in range(LOG)]

# 初期化 (doubling[0] = dp[1])
for k in range(K):
    doubling[0][C[k] % B] += 1

# ダブリング
for i in range(1, LOG):
    doubling[i] = mul(doubling[i - 1], doubling[i - 1], ten[i - 1])

# ダブリングした結果をもとに答えを求める
res = [0] * B
res[0] = 1
for i in range(LOG):
    if N & (1 << i):
        res = mul(res, doubling[i], ten[i])
print(res[0])