けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 182 C - To 3 (灰色, 300 点)

これ灰色ってマジか!!

問題概要

どの桁の値も 0 でないような正の整数  N が与えられる。 N に含まれるいくつかの数値を除去することで、 N が 3 の倍数となるようにしたい。

除去すべき数値の個数の最小値を求めよ (最初から 3 の倍数の場合は 0 個)。ただし、すべての数値を除去することはできない。不可能な場合は -1 を出力せよ。

制約

  •  1 \le N \le 10^{18}

解法 (1):整数論的考察で解く

まず、次の有名事実がある。


整数  N を 3 で割ったあまり = 整数  N の各桁の和を 3 でわったあまり


たとえば  N = 1234567 を 3 で割ったあまりは

(1 + 2 + 3 + 4 + 5 + 6 + 7) % 3 = 28 % 3 = 1

になるというわけだ。特に、「3 の倍数かどうかの判定」は、「各桁の和が 3 の倍数かどうか」によって判定できる。このことの証明はたとえばこのサイトなどを参照。

場合分け

以上の考察に基づいて場合分けする解法が考えられる。

  •  S := 各桁の和を 3 で割ったあまり
  •  a := 各桁のうち、3 の倍数が何個あるか
  •  b := 各桁のうち、3 で割って 1 あまるものが何個あるか
  •  c := 各桁のうち、3 で割って 2 あまるものが何個あるか

このとき、 0a + b + 2c を 3 で割ったあまりが、 S に一致することになる。

 S = 0 のとき

この場合はなにもしなくてよいので、答えは 0 となる

 S = 1 のとき

まず、各桁のうち、3 の倍数となっているものについては、それを除去しても「 N を 3 でわったあまり」は変わらないことに注意しよう。よって基本的には、 b c を減らしていくことを考える。

もし  b \ge 1 ならば、 b を 1 減らせば条件を満たすので、答えは 1 (ただし 1 桁の整数の場合は -1)。

 b = 0 のときは、 c \ge 2 ならば、 c を 2 減らせば条件を満たすので、答えは 2 (ただし 2 桁の整数の場合は -1)。

それ以外は不可能なので -1 となる。

 S = 2 のとき

同様に

  •  c \ge 1 ならば、答えは 1 (ただし 1 桁の整数の場合は -1)
  •  c = 0 かつ  b \ge 2 ならば、答えは 2 (ただし 2 桁の整数の場合は -1)
  • それ以外は -1

コード

場合分けが多いので、慎重にテストした。計算量は  O(\log N)

#include <bits/stdc++.h>
using namespace std;

long long solve(long long N) {
    long long S = 0, D = 0;
    vector<long long> A(3, 0);
    while (N) {
        int d = N % 10;
        A[d % 3]++, S += d, ++D;
        N /= 10;
    }

    if (S % 3 == 0) return 0;
    else if (S % 3 == 1) {
        if (A[1] >= 1 && D > 1) return 1;
        else if (A[2] >= 2 && D > 2) return 2;
        else return -1;
    }
    else if (S % 3 == 2) {
        if (A[2] >= 1 && D > 1) return 1;
        else if (A[1] >= 2 && D > 2) return 2;
        else return -1;
    }
}

int main() {
    long long N;
    cin >> N;
    cout << solve(N) << endl;
}

解法 (2):bit 全探索

解法 (1) は場合分け漏れが怖いので、bit 全探索する方が安全かもしれない。高々 19 桁しかないので、どの桁を除去するのかを bit 全探索するのだ。

bit 全探索の方法については、以下の記事に書いた。

drken1215.hatenablog.com

計算量は  A = \log_{10} N として、 O(A2^{A}) となる。

#include <bits/stdc++.h>
using namespace std;

long long solve(long long N) {
    vector<int> A;
    while (N) {
        A.push_back(N % 10);
        N /= 10;
    }

    const int INF = 1<<29;
    int res = INF;
    for (int bit = 0; bit < (1 << A.size()) - 1; ++bit) {
        int sum = 0, con = 0;
        for (int i = 0; i < A.size(); ++i) {
            if (bit & (1<<i)) ++con;
            else sum += A[i];
        }
        if (sum % 3 == 0) res = min(res, con);
    }
    return (res < INF ? res : -1);
}

int main() {
    long long N;
    cin >> N;
    cout << solve(N) << endl;
}