けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 007 D - 破れた宿題 (試験管橙色)

一瞬激ヤバに見えるし、コーナーケースの数もえげつないけど、とりあえず最小の初項はすぐにわかると...

問題へのリンク

問題概要

等差数列があった。
等差数列を concat して得られる文字列から、先頭から何文字かと、末尾から何文字かを削除して得られた文字列  S (長さ  N) が与えられる。もとの等差数列として考えられるもののうち、(初項, 交差) の辞書順最小値を求めよ。ただし以下の条件を満たすとする

  •  S に残っている文字列は、初項を表す文字列の一部が残っている
  • 元の等差数列は、どの項も leading zero を含まない

制約

  •  1 \le N \le 1000

まずは初項

とてもつらい気持ちになるやつだけど、一つの道しるべとして、「初項」として考えられる最小値だけは一瞬でわかる。

  • S[0] != '0' のとき
    • S[1:] の先頭から  k 文字が 0 であるとき、S[0] + S.substr(1, k) が初項の最小値
  • S の先頭  k 文字が 0 であるとすると、"1" + S.substr(0, k) が初項の最小値
    •  k = N のとき (S = "00...0" のとき) も同様 (この場合は交差の最小値は "1")

たとえばこんな感じになる:

  • "3421" では、初項は "3"
  • "300421" では、初項は "300"
  • "00089" では、初項は "1000"
  • "0" では、初項は "10"
  • "000" では、初項は "1000"

さて、初項がわかってから第二項以降は 2 つの場合がある。

  • 第二項がすべて S に含まれる場合
  • 第二項が途切れている場合

第二項がすべて S に含まれる場合

こちらは単純に「第二項と第三項の切れ目」を全探索して、整合性をチェックすれば OK

第二項が途切れている場合

こちらは、第二項の候補として、以下の二通りを試せば OK

  • 初項 + 1
  • (S の初項を削った残りの値) ×  10^{k} が、初項より大きくなる最小の値

前者が鬼のコーナーケースとなっている。たとえば

  • S = "201" のとき、初項は 20、交差は 80 (第二項が 100)
  • S = "202" のとき、初項は 20、交差は 1 (第二項が 21)
  • S = "203" のとき、初項は 20、交差は 10 (第二項が 30)

という感じ。

コード

計算量は  O(N^{2}) となる。 この問題のために、多倍長整数ライブラリを整えた!!!

#include <iostream>
#include <string>
#include <vector>
#include <sstream>
using namespace std;

const int DEFAULT_SIZE = 125;
struct bint : vector<long long> {
    static const long long BASE = 100000000;
    static const int BASE_DIGIT = 8;
    int sign;

    // constructor
    bint(long long num) : vector<long long>(DEFAULT_SIZE, 0), sign(1) {
        if (num < 0) sign = -1, num = -num;
        (*this)[0] = num; 
        this->normalize();
    }
    bint& normalize() {
        long long c = 0;
        for (int i = 0;; ++i) {
            if (i >= this->size()) this->push_back(0);
            if ((*this)[i] < 0 && i+1 >= this->size()) this->push_back(0);
            while ((*this)[i] < 0) {
                (*this)[i+1] -= 1;
                (*this)[i] += BASE;
            }
            long long a = (*this)[i] + c;
            (*this)[i] = a % BASE;
            c = a / BASE;
            if (c == 0 && i == this->size()-1) break;
        }
        return (*this);
    }
    friend bint abs(const bint &x) {
        bint z = x;
        if (z.sign == -1) z.sign = 1;
        return z;
    }
    friend ostream &operator << (ostream &os, const bint &x) {
        if (x.sign == -1) os << '-';
        int d = x.size()-1; 
        for (d = x.size()-1; d >= 0; --d) if (x[d] > 0) break;
        if (d == -1) os << 0;
        else os << x[d];
        for (int i = d-1; i >= 0; --i) {
            os.width(bint::BASE_DIGIT);
            os.fill('0');
            os << x[i];
        }
        return os;
    }

    // operation
    bint operator - () const {
        bint res = *this;
        bool allzero = true;
        for (int i = 0; i < this->size(); ++i) {
            if (res[i] != 0) {
                allzero = false;
                break;
            }
        }
        if (!allzero) res.sign = -res.sign; 
        return res; 
    }
    bint& operator += (const bint& r) {
        while (size() < r.size()) this->emplace_back(0);
        if (sign == r.sign) {
            for (int i = 0; i < r.size(); ++i) (*this)[i] += r[i];
        }
        else {
            if (sign == 1 && abs(*this) < abs(r)) sign = -1;
            else if (sign == -1 && abs(*this) <= abs(r)) sign = 1;
            if (abs(*this) >= abs(r)) {
                for (int i = 0; i < r.size(); ++i) (*this)[i] -= r[i];
            }
            else {
                for (int i = 0; i < size(); ++i) (*this)[i] = -(*this)[i];
                for (int i = 0; i < r.size(); ++i) (*this)[i] += r[i];
            }
        }
        return this->normalize();
    }
    bint& operator -= (const bint& r) {
        while (size() < r.size()) this->emplace_back(0);
        if (sign == -r.sign) {
            for (int i = 0; i < r.size(); ++i) (*this)[i] += r[i];
        }
        else {
            if (sign == 1 && abs(*this) < abs(r)) sign = -1;
            else if (sign == -1 && abs(*this) <= abs(r)) sign = 1;
            if (abs(*this) >= abs(r)) {
                for (int i = 0; i < r.size(); ++i) (*this)[i] -= r[i];
            }
            else {
                for (int i = 0; i < size(); ++i) (*this)[i] = -(*this)[i];
                for (int i = 0; i < r.size(); ++i) (*this)[i] += r[i];
            }
        }
        return this->normalize();
    }
    bint& operator *= (long long r) {
        if ( (sign == 1 && r >= 0) || (sign == -1 && r < 0) ) sign = 1;
        else sign = -1;
        if (r < 0) r = -r;
        for (int i = 0; i < size(); ++i) (*this)[i] *= r;
        return this->normalize();
    }
    bint& operator *= (const bint& r) {
        int tx = (int)size()-1, ty = (int)r.size()-1; 
        for (tx = size()-1; tx >= 0; --tx) if ((*this)[tx] > 0) break;
        for (ty = r.size()-1; ty >= 0; --ty) if (r[ty] > 0) break;
        bint res(0);
        res.resize(tx+ty+2);
        if (sign == r.sign) res.sign = 1;
        else res.sign = -1;
        for (int i = 0; i <= tx; ++i) {
            for (int j = 0; j <= ty && i+j < (int)res.size()-1; ++j) {
                long long val = (*this)[i] * r[j] + res[i+j];
                res[i+j+1] += val / bint::BASE;
                res[i+j] = val % bint::BASE;
            }
        }
        return (*this) = res.normalize();
    }
    friend bint pow(const bint& a, long long n) {
        bint res(1), b = a;
        while (n > 0) {
            if (n & 1) res = res * b;
            b = b * b;
            n >>= 1;
        }
        return res;
    }
    bint operator + (const bint& r) const { return bint(*this) += r; }
    bint operator - (const bint& r) const { return bint(*this) -= r; }
    bint operator * (long long r) const { return bint(*this) *= r; }
    bint operator * (const bint& r) const { return bint(*this) *= r; }

    // divide
    bint& operator /= (long long r) {
        long long c = 0, t = 0;
        for (int i = (int)size()-1; i >= 0; --i) {
            t = bint::BASE * c + (*this)[i];
            (*this)[i] = t / r;
            c = t % r;
        }
        this->normalize();
        return (*this);
    }
    long long operator %= (long long r) {
        long long c = 0, t = 0;
        for (int i = (int)size()-1; i >= 0; --i) {
            t = bint::BASE * c + (*this)[i];
            (*this)[i] = t / r;
            c = t % r;
        }
        return c;
    }
    bint operator / (long long r) const {
        return bint(*this) /= r;
    }
    long long operator % (long long r) const {
        return bint(*this) %= r;
    }
    friend pair<bint, bint> divmod(const bint &a, const bint &r) {
        bint zero = 0, s = 0, t = 0;
        if (abs(a) < abs(r)) return {zero, a};
        bint ar = abs(r);
        s.resize((int)a.size()), t.resize((int)r.size());
        int tx = (int)a.size()-1;
        for (;tx >= 0; --tx) if (a[tx] > 0) break;
        for (int i = tx; i >= 0; --i) {
            t = t * bint::BASE + a[i];
            long long lo = 0, hi = bint::BASE;
            if (t >= ar) {
                while (hi - lo > 1) {
                    int mid = (hi + lo) / 2;
                    if (ar * mid > t) hi = mid;
                    else lo = mid;
                }
                t -= ar * lo;
            }
            s[i] = lo;
        }
        return make_pair(s.normalize(), t.normalize());
    }
    bint operator / (const bint& r) const {
        return divmod((*this), r).first;
    }
    bint operator % (const bint& r) const {
        return divmod((*this), r).second;
    }
    bint& operator /= (const bint& r) { return (*this) = (*this) / r; }
    bint& operator %= (const bint& r) { return (*this) = (*this) % r; }

    // equality
    friend bool operator < (const bint &x, const bint& y) {
        if (x.sign < y.sign) return true;
        else if (x.sign > y.sign) return false;
        else {
            int tx = (int)x.size()-1, ty = (int)y.size()-1; 
            for (tx = x.size()-1; tx >= 0; --tx) if (x[tx] > 0) break;
            for (ty = y.size()-1; ty >= 0; --ty) if (y[ty] > 0) break;
            if (tx < ty) return true;
            else if (tx > ty) return false;
            else if (x.sign == 1) {
                for (int i = tx; i >= 0; --i)
                    if (x[i] != y[i]) return x[i] < y[i];
                return false;
            }
            else {
                for (int i = tx; i >= 0; --i)
                    if (x[i] != y[i]) return x[i] > y[i];
                return false;
            }
        }
    }
    friend bool operator > (const bint& x, const bint& y) { return y < x; }
    friend bool operator <= (const bint& x, const bint& y) { return !(x > y); }
    friend bool operator >= (const bint& x, const bint& y) { return !(x < y); }
    friend bool operator == (const bint &x, const bint& y) {
        if (x.sign != y.sign) return false;
        int tx = (int)x.size()-1, ty = (int)y.size()-1; 
        for (tx = x.size()-1; tx >= 0; --tx) if (x[tx] > 0) break;
        for (ty = y.size()-1; ty >= 0; --ty) if (y[ty] > 0) break;
        if (tx != ty) return false;
        for (int i = tx; i >= 0; --i)
            if (x[i] != y[i]) return false;
        return true;
    }
    friend bool operator != (const bint& x, const bint& y) { return !(x == y); }
};

bint toBint(const string &s) {
    bint res = 0;
    for (int i = 0; i < s.size(); ++i) {
        res += (long long)(s[i] - '0');
        if (i != s.size()-1) res *= 10;
    }
    return res;
}

string toStr(const bint &r) {
    stringstream ss;
    if (r.sign == -1) ss << '-';
    int d = (int)r.size()-1; 
    for (; d >= 0; --d) if (r[d] > 0) break;
    if (d == -1) ss << 0;
    else ss << r[d];
    for (int i = d-1; i >= 0; --i) {
        ss.width(bint::BASE_DIGIT);
        ss.fill('0');
        ss << r[i];
    }
    return ss.str();
}



// 初項が syoko, 第二項を niko とすることが可能かどうか
// S の先頭は niko の先頭から始まる
bool isValid(const string &S, const bint &syoko, const bint &niko) {
    if (niko <= syoko) return false;
    int N = (int)S.size();
    string sniko = toStr(niko);
    int si = 0;
    for (int i = 0; i < sniko.size() && si < N; ++i) {
        if (sniko[i] != S[si++]) return false;
    }
    if (si == N) return true;

    bint prev = syoko, cur = niko;
    while (si < N) {
        if (S[si] == '0') return false;
        bint next = cur * 2 - prev;
        string snext = toStr(next);
        for (int i = 0; i < snext.size() && si < N; ++i) {
            if (snext[i] != S[si++]) return false;
        }
        if (si < N) prev = cur, cur = next;
    }
    return true;
}

// 試す
void judge(const string &S, const bint& syoko, const bint& niko_koho, bint& niko) {
    if (!isValid(S, syoko, niko_koho)) return;
    if (niko == 0) niko = niko_koho;
    else if (niko_koho < niko) niko = niko_koho;
}

// 解く
void solve(string S) {
    // 初項
    if (S[0] == '0') S = "1" + S;
    int zeronum = 0;
    for (int i = 1; i < S.size(); ++i) {
        if (S[i] != '0') break;
        ++zeronum;
    }
    bint syoko = toBint(S.substr(0, zeronum + 1));
    S = S.substr(zeronum+1);
    int N = (int)S.size();
    if (N == 0) {
        cout << syoko << " " << 1 << endl;
        return;
    }

    bint niko = 0;
    for (int i = 1; i <= N; ++i) {
        bint niko_koho = toBint(S.substr(0, i));
        if (niko_koho <= syoko) continue;
        judge(S, syoko, niko_koho, niko);
    }

    // +1 が valid か
    judge(S, syoko, syoko + 1, niko);

    // 0 を繋げていく
    bint niko_koho = toBint(S);
    while (niko_koho <= syoko) niko_koho *= 10;
    judge(S, syoko, niko_koho, niko);

    // 出力
    cout << syoko << " " << niko - syoko << endl;
}

int main() {
    string S;
    while (cin >> S) solve(S);
}

Codeforces Round #614 (Div. 1) A. NEKO's Maze Game (R1400)

いくらなんでも 2 分で解ける問題とは思えないのですが...

問題へのリンク

問題概要

2 × N のグリッドが与えられる。最初はグリッドのマスはすべて「通路」の状態であって、マス (1, 1) からマス (2, N) に到達することができる。以下の q 個のクエリに答えよ。

  •  i 回目のクエリは (r, c) で与えられ、その時点でもしマス (r, c) が「通路」ならば「壁」になり、「壁」ならば「通路」となる。
  • その処理を行った後で、(1, 1) から (2, N) へ到達できるかどうかを判定せよ

制約

  •  1 \le N \le 10^{5}

考えたこと

これそんなに簡単なの!?
通り抜けられる条件は、以下の条件を満たすことと同値である (0-indexed に直している)

  • 任意の (r, c) に対して、以下のすべてが成立する:
    • (r, c) と (1-r, c-1) のどちらかは通路である
    • (r, c) と (1-r, c) のどちらかは通路である
    • (r, c) と (1-r, c+1) のどちらかは通路である

そこで、この条件の違反している部分の個数を常に更新することで、クエリに答えることができる。

#include <iostream>
#include <vector>
using namespace std;

int main() {
    int N, Q;
    cin >> N >> Q;
    vector<vector<int> > turo(2, vector<int>(N, 1));
    int dame = 0;
    for (int q = 0; q < Q; ++q) {
        int r, c; cin >> r >> c; --r, --c;
        // 新たに壁に
        if (turo[r][c]) {
            for (int dc = -1; dc <= 1; ++dc) {
                int nc = c + dc;
                if (nc < 0 || nc >= N) continue;
                if (!turo[1-r][nc]) ++dame;
            }
            turo[r][c] = false;
        }
        // 壁を消す
        else {
            for (int dc = -1; dc <= 1; ++dc) {
                int nc = c + dc;
                if (nc < 0 || nc >= N) continue;
                if (!turo[1-r][nc]) --dame;
            }
            turo[r][c] = true;
        }
        if (!dame) cout << "Yes" << endl;
        else cout << "No" << endl;
    }
}

Codeforces Round #614 (Div. 1) C. Xenon's Attack on the Gangs (R2300)

面白かった!!!こういうのを確実に通せるようにならないと!!!

問題へのリンク

問題概要

 N 頂点の木が与えられる。木の各辺を  0, 1, \dots, N-2 のラベルをつける方法のうち、

  •  \sum_{1 \le u \lt v \le N} f(u, v)

の値の最大値を求めよ。ただし  f(u, v) は、2 頂点  u, v を結ぶパスに含まれる辺の値の集合を考えたときに、それに含まれない最小の 0 以上の整数値を表す。

制約

  •  2 \le N \le 3000

考えたこと

 f(u, v) の意味が釈然としないので、求める値を、次のように言い換えてみよう。

  • 0 の辺を含むパスの本数
  • 0 の辺と 1 の辺をともに含むようなパスの本数
  • 0 の辺と 1 の辺と 2 の辺をすべて含むようなパスの本数
  • ...

を合計した値に等しくなる。ここまでは少し考えればわかる。

まず 0 を含むパスの本数というのは、「0 の辺」でツリー全体を 2 つに分離したときに、それぞれの部分木のサイズの積に等しい。

同様にして、0 の辺と 1 の辺をともに含むようなパスの本数は、それらの辺をともに含むようなパスをとってきたときに、その両端点を根とする部分木のサイズの積に等しい。

観察

上図では、0 の辺と 1 の辺を引き離して描いたけど、これは明らかに損である。数値を入れ替えて、0 と 1 が隣接するようにした方が得になる。

同様にして、0, 1, 2, ... とラベルを増やしていくときに、新たな値をつける辺を見定めるときには、「すでに値を振ったパス」の両端に接続するようにするのが良いということがいえる。ただし葉に到達したら終了。それ以降は何をやっても値は増えない。

ちゃんとした証明は editorial に。

DP へ

以上のことから、次のような DP が立つ

  • dp[ u ][ v ] := 無の状態から開始して、パス (u, v) まで伸ばす過程で合算されたスコアの最大値

これは、「パスを伸ばす直前がどうだったか」で、以下の 2 通りの場合分けをして遷移できる。ここで s(u, v) を、v を根とし、u を親方向の頂点とした場合の v を根とする部分木のサイズとする。また、

  • u-v パスにおいて、u から v の方向へ 1 辺進んだ点を up
  • u-v パスにおいて、v から u の方向へ 1 辺進んだ点を vp

とする。このとき、遷移は以下のように表せる。

u-v パスが、up-v パスから辺 (up, u) を伸ばすことで成長した場合

chmax(dp[ u ][ v ], dp[ up ][ v ] + s(up, u) × s(vp, v))

u-v パスが、u-vp パスから辺 (vp, v) を伸ばすことで成長した場合

chmax(dp[ u ][ v ], dp[ u ][ vp ] + s(up, u) × s(vp, v))

計算量

まとめると、以下の値を前処理して求めておけば、 O(N^{2}) な DP で求めることができる。

  • s(u, v)
  • u-v パスにおいて、u から v 方向に 1 辺進んだときの頂点

これらはライブラリにしてしまうことにした。前処理は  O(N\log{N}) を要する。上記のうちの後者は LCA を求める要領で  O(\log{N}) で求める。この場合、全体の計算量は  O(N^{2}\log{N}) となる。

ただし今回は前処理に  O(N^{2}) かけてよく、上記の値をすべてメモリに置いておくこともできるので、全体として  O(N^{2}) で解くことができる (公式解答)。

#include <iostream>
#include <vector>
#include <map>
using namespace std;

using Graph = vector<vector<int> >;
struct RunTree {
    // id[v][w] := the index of node w in G[v]
    vector<map<int,int> > id;

    // num[v][i] := the size of subtree of G[v][i] with parent v
    vector<vector<long long> > num;

    // size(u, v) := the size of subtree v with parent u
    long long size(int u, int v) {
        return num[u][id[u][v]];
    }

    // lca(u, v)
    int getLCA(int u, int v) {
        if (depth[u] > depth[v]) swap(u, v);
        for (int i = 0; i < (int)parent.size(); ++i)
            if ( (depth[v] - depth[u]) & (1<<i) )
                v = parent[i][v];
        if (u == v) return u;
        for (int i = (int)parent.size()-1; i >= 0; --i) {
            if (parent[i][u] != parent[i][v]) {
                u = parent[i][u];
                v = parent[i][v];
            }
        }
        return parent[0][u];
    }

    // length(u, v)
    long long length(int u, int v) {
        int lca = getLCA(u, v);
        return depth[u] + depth[v] - depth[lca]*2;
    }

    // getParent(v, p) := the parent of v directed for p
    int getParent(int v, int p) {
        if (v == p) return -1;
        int lca = getLCA(v, p);
        if (lca != v) return parent[0][v];
        for (int i = (int)parent.size()-1; i >= 0; --i) {
            if (parent[i][p] != -1 && depth[parent[i][p]] > depth[v]) {
                p = parent[i][p];
            }
        }
        return p;
    }
    
    // rec
    vector<vector<int> > parent;
    vector<int> depth;
    int rec(const Graph &G, int v, int p = -1, int d = 0) {
        int p_index = -1;
        int sum = 1;
        parent[0][v] = p;
        depth[v] = d;
        for (int i = 0; i < (int)G[v].size(); ++i) {
            int ch = G[v][i];
            id[v][ch] = i;
            if (ch == p) {
                p_index = i;
                continue;
            }
            int s = rec(G, ch, v, d+1);
            num[v][i] = s;
            sum += s;
        }
        if (p_index != -1) num[v][p_index] = (int)G.size() - sum;
        return sum;
    }

    // init
    void init(const Graph &G) {
        int N = (int)G.size();
        id.assign(N, map<int,int>());
        num.assign(N, vector<long long>());
        for (int v = 0; v < N; ++v) num[v].assign((int)G[v].size(), 0);
         int V = (int)G.size();
        int h = 1;
        while ((1<<h) < N) ++h;
        parent.assign(h, vector<int>(N, -1));
        depth.assign(N, -1);
        rec(G, 0);
        for (int i = 0; i+1 < (int)parent.size(); ++i)
            for (int v = 0; v < V; ++v)
                if (parent[i][v] != -1)
                    parent[i+1][v] = parent[i][parent[i][v]];
    }
};



const long long INF = 1LL<<60;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

using pint = pair<int,int>;
int N;
Graph G;
RunTree rt;

vector<vector<long long> > dp;
long long dprec(int u, int v) {
    if (dp[u][v] != INF) return dp[u][v];
    if (dp[v][u] != INF) return dp[v][u];
    if (u == v) return 0;

    long long res = 0;
    int up = rt.getParent(u, v);
    int vp = rt.getParent(v, u);
    chmax(res, dprec(up, v));
    chmax(res, dprec(vp, u));
    res += rt.size(up, u) * rt.size(vp, v);
    return dp[u][v] = dp[v][u] = res;
}

long long solve() {
    rt.init(G);
    dp.assign(N, vector<long long>(N, INF));
    long long res = 0;
    for (int i = 0; i < N; ++i) {
        for (int j = i+1; j < N; ++j) {
            chmax(res, dprec(i, j));
        }
    }
    return res;
}

int main() {
    while (scanf("%d", &N) != EOF) {
        G.assign(N, vector<int>());
        for (int i = 0; i < N-1; ++i) {
            int u, v;
            scanf("%d %d", &u, &v);
            --u, --v;
            G[u].push_back(v);
            G[v].push_back(u);
        }
        cout << solve() << endl;
    }
}

Codeforces Round #609 (Div. 1) C. K Integers (R2300)

とてもこどふぉっぽい問題だと思った!!!こういうのを得意になるぞー!!!

問題へのリンク

問題概要

 1, 2, \dots, N の順列が与えられる。各  k = 1, 2, \dots, N に対して、以下の問いに答えよ。

  • 順列の隣り合う 2 要素を swap して、順列のどこかの場所で  1, 2, \dots, k がこの順に連続で並んでいる状態にしたい
  • それを実現する最小 swap 回数を求めよ。

制約

  •  1 \le N \le 2 \times 10^{5}

考えたこと

 k に関する問題は以下の二部パートにわかれそう

  •  1, 2, \dots, k を連続した場所に集める
  • それらの転倒数を求める

後者については、 1, 2, \dots, k-1 の部分についての転倒数に、 k を足すことで転倒数がいくつ増えるのかを BIT で求めることができる。これは「いつもの転倒数の求め方」と一緒なので、難しくない。

前者については、 1, 2, \dots, k の存在する index を  b_{1}, \dots, b_{k} としたとき、

  •  b_{1}, \dots, b_{k} のメディアンについてはその位置を固定
  • その前後をメディアンのところに持ってくる

とするのが最小になる。このことは、よくある「絶対値の和を最小にするのはメディアンのとき」というのとまったく同様にして証明できる。

僕はメディアンの位置を求めて、その前後に集めるのに必要な操作回数を求める作業は priority_queue を 2 つ使う実装をした。でもよく考えたら、転倒数を求めるのに使った BIT を使いまわせばよかった。

いずれにしても計算量は  O(N\log{N})

メディアンを求めるのに、BIT を使い回す方法

#include <iostream>
#include <vector>
#include <queue>
using namespace std;
#define COUT(x) cout << #x << " = " << (x) << " (L" << __LINE__ << ")" << endl
 
template <class Abel> struct BIT {
    const Abel UNITY_SUM = 0;                     // to be set
    vector<Abel> dat;
    
    /* [1, n] */
    BIT(int n) : dat(n + 1, UNITY_SUM) { }
    void init(int n) { dat.sssign(n + 1, UNITY_SUM); }
    
    /* a is 1-indexed */
    inline void add(int a, Abel x) {
        for (int i = a; i < (int)dat.size(); i += i & -i)
            dat[i] = dat[i] + x;
    }
    
    /* [1, a], a is 1-indexed */
    inline Abel sum(int a) {
        Abel res = UNITY_SUM;
        for (int i = a; i > 0; i -= i & -i)
            res = res + dat[i];
        return res;
    }
    
    /* [a, b), a and b are 1-indexed */
    inline Abel sum(int a, int b) {
        return sum(b - 1) - sum(a - 1);
    }

    /* k-th number (k is 0-indexed) */
    int get(long long k) {
        ++k;
        int res = 0;
        int N = 1; while (N < (int)dat.size()) N *= 2;
        for (int i = N / 2; i > 0; i /= 2) {
            if (res + i < (int)dat.size() && dat[res + i] < k) {
                k = k - dat[res + i];
                res = res + i;
            }
        }
        return res + 1;
    }
    
    /* debug */
    void print() {
        for (int i = 1; i < (int)dat.size(); ++i) cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};
 
int N;
vector<int> p, ip; // ip: 逆順列
 
vector<long long> solve() {
    vector<long long> res(N, 0);
    BIT<long long> bit(N+10), bit2(N+10);
    long long tensum = 0;
    for (int i = 0; i < N; ++i) {
        // 追加転倒数
        long long add_tentou = bit.sum(ip[i]+1, N + 5);
        tensum += add_tentou;
        
        // 情報更新
        bit.add(ip[i]+1, 1);
        bit2.add(ip[i]+1, ip[i]);
        
        // メディアン
        long long medi = bit.get(i/2) - 1; // 0-index に

        // メディアンに集める量
        long long left = i/2;
        long long right = i - i/2;
        long long left_sum = bit2.sum(1, medi + 1);
        long long right_sum = bit2.sum(medi + 2, N + 5);
        long long left_move = ((medi - 1) + (medi - left)) * (i/2) / 2 - left_sum;
        long long right_move = right_sum - ((medi + 1) + (medi + right)) * right / 2;
        res[i] = tensum + left_move + right_move;
    }
 
    return res;
}
 
int main() {
    while (scanf("%d", &N) != EOF) {
        p.resize(N); ip.resize(N);
        for (int i = 0; i < N; ++i) {
            scanf("%d", &p[i]), --p[i];
            ip[p[i]] = i;
        }
        auto res = solve();
        for (int i = 0; i < N; ++i) {
            if (i) printf(" ");
            printf("%lld", res[i]);
        }
        printf("\n");
    }
}

メディアンを求めるのに、priority_queue を 2 つ使った方法

#include <iostream>
#include <vector>
#include <queue>
using namespace std;
 
template <class Abel> struct BIT {
    const Abel UNITY_SUM = 0;                     // to be set
    vector<Abel> dat;
    
    /* [1, n] */
    BIT(int n) : dat(n + 1, UNITY_SUM) { }
    void init(int n) { dat.sssign(n + 1, UNITY_SUM); }
    
    /* a is 1-indexed */
    inline void add(int a, Abel x) {
        for (int i = a; i < (int)dat.size(); i += i & -i)
            dat[i] = dat[i] + x;
    }
    
    /* [1, a], a is 1-indexed */
    inline Abel sum(int a) {
        Abel res = UNITY_SUM;
        for (int i = a; i > 0; i -= i & -i)
            res = res + dat[i];
        return res;
    }
    
    /* [a, b), a and b are 1-indexed */
    inline Abel sum(int a, int b) {
        return sum(b - 1) - sum(a - 1);
    }

    /* k-th number (k is 0-indexed) */
    int get(long long k) {
        ++k;
        int res = 0;
        int N = 1; while (N < (int)dat.size()) N *= 2;
        for (int i = N / 2; i > 0; i /= 2) {
            if (res + i < (int)dat.size() && dat[res + i] < k) {
                k = k - dat[res + i];
                res = res + i;
            }
        }
        return res + 1;
    }
    
    /* debug */
    void print() {
        for (int i = 1; i < (int)dat.size(); ++i) cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};
 
 
int N;
vector<int> p, ip; // ip: 逆順列
 
vector<long long> solve() {
    vector<long long> tentou(N, 0);
    BIT<long long> bit(N+10);
    for (int i = 0; i < N; ++i) {
        long long tmp = bit.sum(ip[i]+1, N+9);
        if (i > 0) tentou[i] = tentou[i-1] + tmp;
        bit.add(ip[i]+1, 1);
    }
 
    vector<long long> res = tentou;
    priority_queue<long long> zen;
    priority_queue<long long, vector<long long>, greater<long long> > kou;
    long long zensum = 0, kousum = 0;
 
    for (int i = 0; i < N; ++i) {
        long long add = ip[i];
 
        // push to zen
        if (zen.size() == kou.size()) {
            if (kou.empty()) zen.push(add), zensum += add;
            else {
                long long mi = kou.top();
                if (add > mi) {
                    kou.pop(); kousum -= mi;
                    zen.push(mi); zensum += mi;
                    kou.push(add); kousum += add;
                }
                else {
                    zen.push(add); zensum += add;
                }
            }
        }
        // push to kou
        else {
            if (zen.empty()) kou.push(add), kousum += add;
            else {
                long long ma = zen.top();
                if (add < ma) {
                    zen.pop(); zensum -= ma;
                    kou.push(ma); kousum += ma;
                    zen.push(add); zensum += add;
                }
                else {
                    kou.push(add); kousum += add;
                }
            }
        }
        long long median = zen.top();
        long long zennum = (long long)zen.size() - 1;
        long long kounum = kou.size();
        long long zenkaisa = ((median - 1) + (median - zennum)) * zennum / 2;
        long long koukaisa = ((median + 1) + (median + kounum)) * kounum / 2;
 
        long long zenadd = zenkaisa - (zensum - median);
        long long kouadd = kousum - koukaisa;
        res[i] += zenadd + kouadd;
    }
    return res;
}
 
int main() {
    while (scanf("%d", &N) != EOF) {
        p.resize(N); ip.resize(N);
        for (int i = 0; i < N; ++i) {
            scanf("%d", &p[i]), --p[i];
            ip[p[i]] = i;
        }
        auto res = solve();
        for (int i = 0; i < N; ++i) {
            if (i) printf(" ");
            printf("%lld", res[i]);
        }
        printf("\n");
    }
}

Codeforces Round #609 (Div. 1) B. Domino for Young (R2000)

半分エスパー

問題へのリンク

問題概要

 i 列目の高さが  a_{i} であるようなヤング図形が与えられる。

f:id:drken1215:20200119144557p:plain

これを 1 × 2 のドミノを重ならないように敷き詰めたい。最大で何個置けるか?

制約

  •  1 \le N \le 3 \times 10^{5}

考えたこと

ドミノ敷き詰め系の問題は、市松模様に塗るのが基本ではある。そして、

  • 黒の個数
  • 白の個数

のうちの小さい方が答えではないかと。それで正解だった。

f:id:drken1215:20200119144845p:plain

#include <iostream>
#include <vector>
using namespace std;

int N;
vector<long long> a;
 
long long solve() {
    long long b = 0, w = 0;
    for (int i = 0; i < N; ++i) {
        if (i % 2 == 0) b += a[i]/2, w += (a[i]+1)/2;
        else b += (a[i]+1)/2, w += a[i]/2;
    }
    return min(b, w);
}
 
int main() {
    scanf("%d", &N);
    a.resize(N);
    for (int i = 0; i < N; ++i) scanf("%lld", &a[i]);
    auto res = solve();
    printf("%lld\n", res);
}