けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AOJ 2679 Decoding Ancient Messages (JAG 春コン 2014 C) (700 点)

多倍長整数を活用した、重み付き二部マッチング

問題概要

 N \times N のグリッドが与えられる。各マスにはアルファベット (英大文字と英小文字) が描かれている。今、次の条件を満たすように  N 個の文字を抜き出す

  • 各行から選ぶマスはちょうど 1 個
  • 各列から選ぶマスはちょうど 1 個

そして、 N 個の文字を並び替えて連結して  N 文字の文字列を作る。こうして作られる文字列のうち、辞書順最小のものを求めよ。

制約

  •  1 \le N \le 50

考えたこと

長さが等しくて、アルファベット順にソートされた文字列の辞書順比較は、次のようにして実現できる。まず、文字列を頻度分布に直す。

"AAAABDD" → (4, 1, 0, 2, 0, 0, ...)

このようにしたとき、元の文字列  S, T と、頻度分布に変換したベクトル  f(S), f(T) とでは

 S \lt T f(S) \gt f(T)

という関係が成立する。よって  N! 通りの選択肢の中から

  • アルファベット 'A' が極力多くなるように選び
  • そのうちアルファベット 'B' が極力多くなるように選び
  • ...

という Greedy な考え方で選ぶことになる。これを実現するためには、 x を十分大きな整数として

  • アルファベット 'A' を、数値  x^{51} に対応
  • アルファベット 'B' を、数値  x^{50} に対応
  • ...
  • アルファベット 'z' を、数値  x^{0} に対応

として、選んだ  N 個のアルファベットに対応する数値の総和が「最大」となるようにすれば OK。

多倍長の二部マッチングへ

ここから先は通常の重み付き二部マッチングと考えて、最小費用流問題に帰着すれば OK。

ただし、総和が「最大」ではなく、「最小」になるようにする必要があるので、十分大きな整数  M を用意して、

  • アルファベット 'A' を、数値  M - x^{51} に対応
  • アルファベット 'B' を、数値  M - x^{50} に対応
  • ...
  • アルファベット 'z' を、数値  M - x^{0} に対応

という風にして、最小費用流問題に帰着すれば OK。復元はちょっと面倒。。。

コード

計算量は  O(N^{3}) ではあるものの、定数ははちゃめちゃ重たい。

#include <bits/stdc++.h>
using namespace std;

const int DEFAULT_SIZE = 125;
struct bint : vector<long long> {
    static const long long BASE = 100000000;
    static const int BASE_DIGIT = 8;
    int sign;

    // constructor
    bint(long long num = 0) : vector<long long>(DEFAULT_SIZE, 0), sign(1) {
        if (num < 0) sign = -1, num = -num;
        (*this)[0] = num; 
        this->normalize();
    }
    bint(int size, long long num) : vector<long long>(size, num), sign(1) {}
    bint& normalize() {
        long long c = 0;
        bool exist = false;
        for (int i = 0;; ++i) {
            if (i >= this->size()) this->push_back(0);
            if ((*this)[i] < 0 && i+1 >= this->size()) this->push_back(0);
            while ((*this)[i] < 0) {
                (*this)[i+1] -= 1;
                (*this)[i] += BASE;
            }
            long long a = (*this)[i] + c;
            (*this)[i] = a % BASE;
            if ((*this)[i]) exist = true;
            c = a / BASE;
            if (c == 0 && i == this->size()-1) break;
        }
        if (!exist) sign = 1;
        return (*this);
    }
    friend bint abs(const bint &x) {
        bint z = x;
        if (z.sign == -1) z.sign = 1;
        return z;
    }

    // operation
    bint operator - () const {
        bint res = *this;
        bool allzero = true;
        for (int i = 0; i < this->size(); ++i) {
            if (res[i] != 0) {
                allzero = false;
                break;
            }
        }
        if (!allzero) res.sign = -res.sign; 
        return res; 
    }
    bint& operator += (const bint& r) {
        while (size() < r.size()) this->emplace_back(0);
        if (sign == r.sign) {
            for (int i = 0; i < r.size(); ++i) (*this)[i] += r[i];
        }
        else {
            if (sign == 1 && abs(*this) < abs(r)) sign = -1;
            else if (sign == -1 && abs(*this) <= abs(r)) sign = 1;
            if (abs(*this) >= abs(r)) {
                for (int i = 0; i < r.size(); ++i) (*this)[i] -= r[i];
            }
            else {
                for (int i = 0; i < size(); ++i) (*this)[i] = -(*this)[i];
                for (int i = 0; i < r.size(); ++i) (*this)[i] += r[i];
            }
        }
        return this->normalize();
    }
    bint& operator -= (const bint& r) {
        while (size() < r.size()) this->emplace_back(0);
        if (sign == -r.sign) {
            for (int i = 0; i < r.size(); ++i) (*this)[i] += r[i];
        }
        else {
            if (sign == 1 && abs(*this) < abs(r)) sign = -1;
            else if (sign == -1 && abs(*this) <= abs(r)) sign = 1;
            if (abs(*this) >= abs(r)) {
                for (int i = 0; i < r.size(); ++i) (*this)[i] -= r[i];
            }
            else {
                for (int i = 0; i < size(); ++i) (*this)[i] = -(*this)[i];
                for (int i = 0; i < r.size(); ++i) (*this)[i] += r[i];
            }
        }
        return this->normalize();
    }
    bint& operator *= (long long r) {
        if ( (sign == 1 && r >= 0) || (sign == -1 && r < 0) ) sign = 1;
        else sign = -1;
        if (r < 0) r = -r;
        for (int i = 0; i < size(); ++i) (*this)[i] *= r;
        return this->normalize();
    }
    bint& operator *= (const bint& r) {
        int tx = (int)size()-1, ty = (int)r.size()-1; 
        for (tx = size()-1; tx >= 0; --tx) if ((*this)[tx] > 0) break;
        for (ty = r.size()-1; ty >= 0; --ty) if (r[ty] > 0) break;
        bint res(0);
        res.resize(tx+ty+2);
        if (sign == r.sign) res.sign = 1;
        else res.sign = -1;
        for (int i = 0; i <= tx; ++i) {
            for (int j = 0; j <= ty && i+j < (int)res.size()-1; ++j) {
                long long val = (*this)[i] * r[j] + res[i+j];
                res[i+j+1] += val / bint::BASE;
                res[i+j] = val % bint::BASE;
            }
        }
        return (*this) = res.normalize();
    }
    friend bint pow(const bint& a, long long n) {
        bint res(1), b = a;
        while (n > 0) {
            if (n & 1) res = res * b;
            b = b * b;
            n >>= 1;
        }
        return res;
    }
    bint operator + (const bint& r) const { return bint(*this) += r; }
    bint operator - (const bint& r) const { return bint(*this) -= r; }
    bint operator * (long long r) const { return bint(*this) *= r; }
    bint operator * (const bint& r) const { return bint(*this) *= r; }

    // divide
    bint& operator /= (long long r) {
        if (r < 0) sign *= -1, r = -r;
        long long c = 0, t = 0;
        for (int i = (int)size()-1; i >= 0; --i) {
            t = bint::BASE * c + (*this)[i];
            (*this)[i] = t / r;
            c = t % r;
        }
        this->normalize();
        return (*this);
    }
    long long operator %= (long long r) {
        if (r < 0) sign *= -1, r = -r;
        long long c = 0, t = 0;
        for (int i = (int)size()-1; i >= 0; --i) {
            t = bint::BASE * c + (*this)[i];
            (*this)[i] = t / r;
            c = t % r;
        }
        return c;
    }
    bint operator / (long long r) const {
        return bint(*this) /= r;
    }
    long long operator % (long long r) const {
        return bint(*this) %= r;
    }
    friend pair<bint, bint> divmod(const bint &a, const bint &r) {
        bint zero = 0, s = 0, t = 0;
        if (abs(a) < abs(r)) return {zero, a};
        bint ar = abs(r);
        s.resize((int)a.size()), t.resize((int)r.size());
        int tx = (int)a.size()-1;
        for (;tx >= 0; --tx) if (a[tx] > 0) break;
        for (int i = tx; i >= 0; --i) {
            t = t * bint::BASE + a[i];
            long long lo = 0, hi = bint::BASE;
            if (t >= ar) {
                while (hi - lo > 1) {
                    int mid = (hi + lo) / 2;
                    if (ar * mid > t) hi = mid;
                    else lo = mid;
                }
                t -= ar * lo;
            }
            s[i] = lo;
        }
        if (a.sign == r.sign) s.sign = 1, t.sign = 1;
        else s.sign = -1, t.sign = 1;
        return make_pair(s.normalize(), t.normalize());
    }
    bint operator / (const bint& r) const {
        return divmod((*this), r).first;
    }
    bint operator % (const bint& r) const {
        return divmod((*this), r).second;
    }
    bint& operator /= (const bint& r) { return (*this) = (*this) / r; }
    bint& operator %= (const bint& r) { return (*this) = (*this) % r; }

    // equality
    friend bool operator < (const bint &x, const bint& y) {
        if (x.sign < y.sign) return true;
        else if (x.sign > y.sign) return false;
        else {
            int tx = (int)x.size()-1, ty = (int)y.size()-1; 
            for (tx = x.size()-1; tx >= 0; --tx) if (x[tx] > 0) break;
            for (ty = y.size()-1; ty >= 0; --ty) if (y[ty] > 0) break;
            if (tx < ty) return true;
            else if (tx > ty) return false;
            else if (x.sign == 1) {
                for (int i = tx; i >= 0; --i)
                    if (x[i] != y[i]) return x[i] < y[i];
                return false;
            }
            else {
                for (int i = tx; i >= 0; --i)
                    if (x[i] != y[i]) return x[i] > y[i];
                return false;
            }
        }
    }
    friend bool operator > (const bint& x, const bint& y) { return y < x; }
    friend bool operator <= (const bint& x, const bint& y) { return !(x > y); }
    friend bool operator >= (const bint& x, const bint& y) { return !(x < y); }
    friend bool operator == (const bint &x, const bint& y) {
        if (x.sign != y.sign) return false;
        int tx = (int)x.size()-1, ty = (int)y.size()-1; 
        for (tx = x.size()-1; tx >= 0; --tx) if (x[tx] > 0) break;
        for (ty = y.size()-1; ty >= 0; --ty) if (y[ty] > 0) break;
        if (tx != ty) return false;
        for (int i = tx; i >= 0; --i)
            if (x[i] != y[i]) return false;
        return true;
    }
    friend bool operator != (const bint& x, const bint& y) { return !(x == y); }
};

bint toBint(const string &is) {
    string s = is;
    if (s[0] == '-') s = s.substr(1);
    while (s.size() % bint::BASE_DIGIT != 0) s = "0" + s;
    int N = (int)s.size();
    bint res(N/bint::BASE_DIGIT, 0);
    for (int i = 0; i < (int)s.size(); ++i) {
        res[(N-i-1)/bint::BASE_DIGIT] *= 10;
        res[(N-i-1)/bint::BASE_DIGIT] += s[i] - '0';
    }
    if (is[0] == '-') res.sign = -1;
    return res;
}

string toStr(const bint &r) {
    stringstream ss;
    if (r.sign == -1) ss << '-';
    int d = (int)r.size()-1; 
    for (; d >= 0; --d) if (r[d] > 0) break;
    if (d == -1) ss << 0;
    else ss << r[d];
    for (int i = d-1; i >= 0; --i) {
        ss.width(bint::BASE_DIGIT);
        ss.fill('0');
        ss << r[i];
    }
    return ss.str();
}

istream &operator >> (istream &is, bint &x) {
    string s; is >> s;
    x = toBint(s);
    return is;
}

ostream &operator << (ostream &os, const bint &x) {
    if (x.sign == -1) os << '-';
    int d = x.size()-1; 
    for (d = x.size()-1; d >= 0; --d) if (x[d] > 0) break;
    if (d == -1) os << 0;
    else os << x[d];
    for (int i = d-1; i >= 0; --i) {
        os.width(bint::BASE_DIGIT);
        os.fill('0');
        os << x[i];
    }
    return os;
}

// edge class (for network-flow)
template<class FLOWTYPE, class COSTTYPE> struct Edge {
    int rev, from, to, id;
    FLOWTYPE cap, icap;
    COSTTYPE cost;
    Edge(int r, int f, int t, FLOWTYPE ca, COSTTYPE co, int id = -1) :
        rev(r), from(f), to(t), cap(ca), icap(ca), cost(co), id(id) {}
    friend ostream& operator << (ostream& s, const Edge& E) {
        if (E.cap > 0)
            return s << E.from << "->" << E.to <<
                '(' << E.cap << ',' << E.cost << ')';
        else return s;
    }
};

// graph class (for network-flow)
template<class FLOWTYPE, class COSTTYPE> struct Graph {
    vector<vector<Edge<FLOWTYPE, COSTTYPE> > > list;
    
    Graph(int n = 0) : list(n) { }
    void init(int n = 0) { list.clear(); list.resize(n); }
    void reset() { for (int i = 0; i < (int)list.size(); ++i) for (int j = 0; j < list[i].size(); ++j) list[i][j].cap = list[i][j].icap; }
    inline vector<Edge<FLOWTYPE, COSTTYPE> >& operator [] (int i) { return list[i]; }
    inline const size_t size() const { return list.size(); }
    
    inline Edge<FLOWTYPE, COSTTYPE> &redge(const Edge<FLOWTYPE, COSTTYPE> &e) {
        if (e.from != e.to) return list[e.to][e.rev];
        else return list[e.to][e.rev + 1];
    }
    
    void addedge(int from, int to, FLOWTYPE cap, COSTTYPE cost, int id = -1) {
        list[from].push_back(Edge<FLOWTYPE, COSTTYPE>((int)list[to].size(), from, to, cap, cost, id));
        list[to].push_back(Edge<FLOWTYPE, COSTTYPE>((int)list[from].size() - 1, to, from, 0, -cost));
    }
    
    void add_undirected_edge(int from, int to, FLOWTYPE cap, COSTTYPE cost, int id = -1) {
        list[from].push_back(Edge<FLOWTYPE, COSTTYPE>((int)list[to].size(), from, to, cap, cost, id));
        list[to].push_back(Edge<FLOWTYPE, COSTTYPE>((int)list[from].size() - 1, to, from, cap, cost, id));
    }

    friend ostream& operator << (ostream& s, const Graph& G) {
        s << endl;
        for (int i = 0; i < G.size(); ++i) {
            s << i << ":";
            for (auto e : G.list[i]) s << " " << e;
            s << endl;
        }
        return s;
    }   
};

// min-cost flow (by primal-dual)
template<class FLOWTYPE, class COSTTYPE> COSTTYPE MinCostFlow(Graph<FLOWTYPE, COSTTYPE> &G, int s, int t, FLOWTYPE f) {
    int n = (int)G.size();
    vector<COSTTYPE> pot(n, 0), dist(n, -1);
    vector<int> prevv(n), preve(n);
    COSTTYPE res = 0;
    while (f > 0) {
        priority_queue<pair<COSTTYPE,int>, vector<pair<COSTTYPE,int> >, greater<pair<COSTTYPE,int> > > que;
        dist.assign(n, -1);
        dist[s] = 0;
        que.push(make_pair(0,s));
        while(!que.empty()) {
            pair<COSTTYPE,int> p = que.top();
            que.pop();
            int v = p.second;
            if (dist[v] < p.first) continue;
            for (int i = 0; i < G[v].size(); ++i) {
                auto e = G[v][i];
                if (e.cap > 0 && (dist[e.to] < 0 || dist[e.to] > dist[v] + e.cost + pot[v] - pot[e.to])) {
                    dist[e.to] = dist[v] + e.cost + pot[v] - pot[e.to];
                    prevv[e.to] = v;
                    preve[e.to] = i;
                    que.push(make_pair(dist[e.to], e.to));
                }
            }
        }
        if (dist[t] < 0) return -1;
        for (int v = 0; v < n; ++v) pot[v] += dist[v];
        FLOWTYPE d = f;
        for (int v = t; v != s; v = prevv[v]) {
            d = min(d, G[prevv[v]][preve[v]].cap);
        }
        f -= d;
        res += pot[t] * d;
        for (int v = t; v != s; v = prevv[v]) {
            Edge<FLOWTYPE,COSTTYPE> &e = G[prevv[v]][preve[v]];
            Edge<FLOWTYPE,COSTTYPE> &re = G.redge(e);
            e.cap -= d;
            re.cap += d;
        }
    }
    return res;
}

int main() {
    long long x = 55;
    vector<bint> power(300, 1);
    for (int i = 1; i < 300; ++i) power[i] = power[i-1] * x;
    bint M = power.back();

    int N;
    cin >> N;
    vector<string> fi(N);
    for (int i = 0; i < N; ++i) cin >>fi[i];
    int s = N*2, t = N*2+1;
    Graph<int,bint> G(N*2+2);
    auto cost = [&](char c) {
        if (c >= 'A' && c <= 'Z') return M - power[51-(c-'A')];
        else return M - power[25-(c-'a')];
    };
    for (int i = 0; i < N; ++i) {
        G.addedge(s, i, 1, 0), G.addedge(i+N, t, 1, 0);
        for (int j = 0; j < N; ++j) G.addedge(i, j+N, 1, cost(fi[i][j]));
    }
    auto mincost = MinCostFlow(G, s, t, N);

    // reconstruct
    mincost = M * N - mincost;
    vector<long long> hist;
    while (mincost > 0) {
        hist.push_back(mincost % x);
        mincost /= x;
    }
    string res = "";
    for (int i = 0; i < hist.size(); ++i) {
        char c;
        if (i < 26) c = 'z'-i;
        else c = 'Z'-(i-26);
        for (int j = 0; j < hist[i]; ++j) res += c;
    }
    reverse(res.begin(), res.end());
    cout << res << endl;
}