けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AOJ 2574 Magical Switches (JAG 模擬地区 2013 J) (900 点)

枝刈り探索が根本的に計算量改善することを示せることがある!!!
有名な例は最大独立集合問題に対する  O(1.466^{N}) のアルゴリズムかなと。

www.slideshare.net

問題へのリンク

問題概要

下図 (公式解説より) のように  3 \times W なグリッドが与えられる。グリッドの横方向には、幅  2 ずつの「アルファベットの描かれた塊」が  M 個分ある。

  • アルファベット小文字のところは通過できない
  • アルファベット大文字のところは通過できる

という感じになっていて、左から右へ通り抜けられるようにしたい。

スイッチが「アルファベット」に対応して 26 個ある。それぞれのスイッチを押すと、対応するアルファベットについて、小文字が通過できるようになり、大文字が通過できないようになる。

上手にスイッチを押すことで、左から右へと通り抜け可能にできるかどうかを判定し、できるならば具体的な方法を 1 つ示せ (押すスイッチの集合を出力)。

制約

  • テストケース数は  130 以下
  •  1 \le M \le 1000

考えたこと

ずっと前に解説だけ聞いて、「枝刈り探索で計算量が落ちる」という話に感動した覚えがあったけど、ちゃんと解いてなかったから解いてみた!

枝刈りが計算量を落とすカラクリ

まず、枝刈りで計算量が落ちるカラクリを。
たとえば部分和問題に対する全探索は、 N 個の整数があるときに、「 a_{0} を選ぶ」「 a_{0} を選ばない」という 2 つの場合に分岐して再帰的に解く全探索手法の計算量  T(N) は、

  •  T(N) \le 2T(N-1)

という感じになって、 T(N) = O(2^{N}) となるわけだ。しかしもし仮に、  a_{i} を選んだ場合には、 a_{i+1} は選ばなくてもいいよ、ということがわかるような構造が含まれていたならば

  •  T(N) \le T(N-1) + T(N-2)

という感じになる。この場合は実はこれを解くと  T(N) = O((\frac{1 + \sqrt{5}}{2})^{N}) となるのだ。これは  T(N) - T(N-1) - T(N-2) = 0 の特性多項式である  x^{2} + x + 1 = 0 を解くことで得られる。

3-SAT の場合

今回の問題は、実は 3-SAT に帰着して解くのだが、3-SAT も類似の探索が効く構造になっている。変数が  N 個あるとする。部分和問題と同じように、各変数について「T とするとき」「F とするとき」に分岐して解いていくことになる。たとえば

  • (a & b & c) | (...

という風になっていたとき、実は以下の 3 パターンのみ考えれば OK

  • a = T のとき
  • a = F で、b = T のとき
  • a = b = F で、c = T のとき

この 3 パターンに分岐するとなると、計算量は

  •  T(N) \le T(N-1) + T(N-2) + T(N-3)

となる。特性多項式である  x^{3} = x^{2} + x + 1 を解くと、調べるべき充足パターンは、なんと  O(1.8393^{N}) となるようだ!!!!!

今回の問題を 3-SAT に

それでは、元々の今回の問題を 3-SAT に帰着してみる。 2 \times 3 のマス目について考えると、以下の 8 つの条件をすべて満たすことと同値であることがわかる。

  •  (1, 1), (2, 1), (3, 1) のどれかは通過できる必要がある
  •  (1, 1), (2, 1), (3, 2) のどれかは通過できる必要がある
  •  (1, 1), (2, 2), (3, 1) のどれかは通過できる必要がある
  •  (1, 1), (2, 2), (3, 2) のどれかは通過できる必要がある
  •  (2, 1), (2, 1), (3, 1) のどれかは通過できる必要がある
  •  (2, 1), (2, 1), (3, 2) のどれかは通過できる必要がある
  •  (2, 1), (2, 2), (3, 1) のどれかは通過できる必要がある
  •  (2, 1), (2, 2), (3, 2) のどれかは通過できる必要がある

というわけで 3-SAT に帰着されるのだ!!!!!!

#include <iostream>
#include <vector>
#include <string>
#include <map>
#include <set>
#include <algorithm>
using namespace std;


using CL = vector<pair<int,int> >;
bool rec(vector<int> &res, int k, const vector<CL> &cls) {
    if (k == cls.size()) return true;
    vector<pair<int,int> > rem;
    bool ok = false;
    for (int i = 0; i < cls[k].size(); ++i) {
        if (res[cls[k][i].first] == cls[k][i].second) {
            ok = true;
            break;
        }
        else if (res[cls[k][i].first] == -1) rem.push_back(cls[k][i]);
    }
    if (ok) return rec(res, k+1, cls);
    if (!ok && rem.empty()) return false;
    auto copy_res = res;
    for (int i = 0; i < rem.size(); ++i) {
        for (int j = 0; j < i; ++j) res[rem[j].first] = 1 - rem[j].second;
        res[rem[i].first] = rem[i].second;
        if (rec(res, k+1, cls)) return true;
    }
    res = copy_res;
    return false;
}

vector<int> solveSAT(int N, vector<CL> &cls) {
    for (int i = 0; i < (int)cls.size(); ++i) {
        map<int, set<int> > so;
        for (int j = 0; j < (int)cls[i].size(); ++j) {
            so[cls[i][j].first].insert(cls[i][j].second);
        }
        CL ncl;
        bool already = false;
        for (auto it : so) {
            if (it.second.size() == 2) already = true;
            ncl.push_back({it.first, *(it.second.begin())});
        }
        if (already) cls.erase(cls.begin() + i--);
        else cls[i] = ncl;
    }
    sort(cls.begin(), cls.end());
    for (int i = 0; i+1 < (int)cls.size(); ++i) {
        if (cls[i] == cls[i+1]) cls.erase(cls.begin() + i--);
    }
    vector<int> res(N, -1);
    if (rec(res, 0, cls)) return res;
    else return vector<int>();
}

int main() {
    int M;
    while (cin >> M, M) {
        vector<string> v(3);
        for (int i = 0; i < 3; ++i) cin >> v[i];
        
        vector<CL> cls;
        int iter = 0;
        for (int i = 0; i < M; ++i) {
            vector<int> pos({i*3+1, i*3+2});
            for (int bit = 0; bit < (1<<3); ++bit) {
                CL cl(3);
                for (int j = 0; j < 3; ++j) {
                    char c = (bit & (1<<j)) ? v[j][pos[1]] : v[j][pos[0]]; 
                    if (c >= 'a' && c <= 'z') cl[j] = {c - 'a', 1};
                    else cl[j] = {c - 'A', 0};
                }
                cls.push_back(cl);
            }
        }
        
        auto res = solveSAT(26, cls);
        if (res.empty()) cout << -1 << endl;
        else {
            vector<char> ans;
            for (int i = 0; i < 26; ++i) {
                if (res[i] == 1) ans.push_back('A' + i);
            }
            cout << ans.size() << endl;
            for (auto c : ans) cout << c << endl;
        }
    }
}