けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

桁 DP の思想 〜 K 以下の整数を走査するとはどういうことか 〜

K 以下の整数を分類する

今日の ABC D 問題 で話題になったので書いてみます。競プロで


  • 非負整数  X X \le K の範囲を動くときの、〜〜〜の最大値を求めよ

  • 非負整数  X X \le K の範囲を動くときの、〜〜〜という条件を満たすものは何通りあるか


という形をした問題は非常に多いです。この種のタイプの問題に対して「思考停止で桁 DP」は大変有力ですが、そもそも  K 以下の整数というものがどんなものなのかを理解しておくことは有益だと思います。

例えば  K = 8357 の場合については、以下を見れば一目瞭然です。

f:id:drken1215:20190204005438p:plain

噂の D 問題は、この場合分けが頭にあれば自然に解ける問題でした。

桁 DP とは

桁 DP とは、まさに「 K 以下の整数が上記のように分類できることを利用した走査法」であると言えます。

上記のイメージを持って、桁 DP を学ぶとスッと頭に入る気がします。まずそもそも桁 DP を適用できるためには、 X を定めたときの全体スコアを

  •  X の上から 0 桁目を決めたたときのスコア
  •  X の上から 1 桁目を決めたたときのスコア
  • ...

の総和 (かその亜種) になるように表現してあげる必要があります。ABC 117 D であれば、 X などを二進法で考えたときに、 X の桁数を  s として、

  •  X の上から i 桁目を 1 にするなら、 A_1, A_2, \dots, A_N をそれぞれ i 桁目をビット反転したときに、i 桁目が 1 となるものの個数を  a として、 2^{s-1-i} ×  a
  •  X の上から i 桁目を 0 にするなら、 A_1, A_2, \dots, A_N をそれぞれ i 桁目をビット反転しないときに i 桁目が 1 となるものの個数を  b として、 2^{s-1-i} ×  b

という風にしてあげればよいです。これを各 i に対して合計したものが最終スコアになります。そしてそれを元にして、桁 DP はベースの考え方としては


dp[ i ] := 上から i 桁目まで決めたときの暫定スコアの最大値とか最小値とか


とするわけですが、例えば  K = 8357 のとき、

  • dp[ 2 ] を最大にする  X が 82 とかであれば、これは 83 より小さいので、82 の次の桁の数は何でもいい

  • dp[ 2 ] を最大にする  X が 83 であれば、83 の次の数は 5 を超えてはならない

という風に状況が分かれてしまい、困ってしまいます。そこで多くの場合


dp[ i ][ smaller ] :=  X を上から i 桁目まで決めたとき、それを  K の上から i 桁目までと比較したときに、

  • smaller = 0 のとき、 K とちょうど一致する場合
  • smaller = 1 のとき、 K よりも小さくなっている場合

についての、スコアの最大値


という風にしてあげます。dp[ i ][ smaller ] から dp[ i + 1 ][ smaller ] への遷移を考えるとき、

  • smaller = 1 からは smaller = 1 にしか遷移せず、i 桁目は何を選んでもよいので全体の中から最適なものを選ぶ

  • smaller = 0 から smaller = 0 への遷移は、K の i 桁目に忠実に合わせる (選択の余地はない)

  • smallse = 0 から smaller = 1 への遷移は、K の i 桁目より小さい範囲から最適なものを選ぶ

という風にしてあげればいいです。あとは問題に応じて DP 遷移式の詳細を組み立てることになります。具体的な問題は

に詳しいです!ここに載っている問題たちを解くと、桁DPは思考停止で書けるようになっている気がします。

ここでは ABC 117 D - XXOR についてのコードを載せてみます。 X A_i たちを仮想的に二進法で 50 桁だとしています (50 桁に満たないものは左から 0 埋めします)。

#include <iostream>
#include <vector>
#include <cstring>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }

const int MAX_DIGIT = 50;
long long dp[100][2]; // dp[上から i 桁まで][ smaller ]

int main() {
    int N;
    long long K;
    cin >> N >> K;
    vector<long long> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];
    
    memset(dp, -1, sizeof(dp)); // DP 配列を -1 で初期化
    dp[0][0] = 0; // 初期条件
    for (int d = 0; d < MAX_DIGIT; ++d) {
        long long mask = 1LL<<(MAX_DIGIT - d - 1);
        
        // A で元々 d 桁目にビットが立っているものの個数
        int num = 0;
        for (int i = 0; i < N; ++i) if (A[i] & mask) ++num;
        
        // X の d 桁目を 0, 1 にしたときのコスト
        long long cost0 = mask * num;
        long long cost1 = mask * (N - num);
        
        // smaller -> smaller
        if (dp[d][1] != -1) {
            chmax(dp[d+1][1], dp[d][1] + max(cost0, cost1)); // 0 でも 1 でも自在に大きい方
        }
        
        // exact -> smaller
        if (dp[d][0] != -1) {
            if (K & mask) { // K の d 桁目が 1 だったら、X の d 桁目は 0 にする
                chmax(dp[d+1][1], dp[d][0] + cost0);
            }
        }
        
        // exact -> exact (K にぴったり合わせる)
        if (dp[d][0] != -1) {
            if (K & mask) chmax(dp[d+1][0], dp[d][0] + cost1);
            else chmax(dp[d+1][0], dp[d][0] + cost0);
        }
    }
    cout << max(dp[MAX_DIGIT][0], dp[MAX_DIGIT][1]) << endl;
}