けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 193 D - Poker (緑色, 400 点)

発想や考え方はそんなに難しくないんだけど、すごく頭がこんがらがってしまう問題だね...

問題概要

 1, 2, \dots, 9 が表に書かれたカードが  K 枚ずつ、計  9K 枚のカードがあります。

これらのカードをランダムにシャッフルして、高橋くんと青木くんにそれぞれ、4 枚を表向きに、1 枚を裏向きにして配りました。高橋くんに配られたカードが文字列  S として、青木くんに配られたカードが文字列  T として与えられます。 S, T は 5 文字の文字列で、先頭 4 文字は 1, 2, …, 9 からなり、表向きのカードに書かれた数を表します。末尾 1 文字は # であり、裏向きのカードであることを表します。

5 枚の手札の点数を、 c_i をその手札に含まれる  i の枚数として、 \sum_{i=1}^{9} = i \times 10^{c_{i}} で定義します。 高橋くんが青木くんより点数の高い手札を持っていたら高橋くんの勝ちです。

高橋くんの勝つ確率を求めてください。

制約

  •  2 \le K \le 10^{5}

考えたこと

 K \le 10^{5} という制約を見ると、計算量に工夫がいるのかなと錯覚してしまうけど、冷静に考えると

  •  S の 5 枚目のカードがどれか (1, 2, ..., 9 の 9 通り)
  •  T の 5 枚目のカードがどれか (1, 2, ..., 9 の 9 通り)

で、考えられる場合が 81 通りしかない。なので、全部試せば OK!!!81 通りそれぞれに対して、

  • そのパターンが出る確率
  • 高橋くんが勝てるかどうか

を判定できれば OK。後者の「高橋くんが勝てるかどうか」は計算するだけ (ただし 0 回しか登場しない数値にも得点が発生することに注意) なので、前者の「そのパターンが出る確率」を考えよう。

サンプルで確率計算

とりあえずサンプルを使って考えてみよう!サンプル 1 よりもむしろ、サンプル 3 の方が様子を掴みやすい!

6
1122#
2228#

まず、1, 1, 2, 2, 2, 2, 2, 8 を使用した時点で残っているカードは

  • 1 が 4 枚
  • 2 が 1 枚
  • 3 が 6 枚
  • 4 が 6 枚
  • 5 が 6 枚
  • 6 が 6 枚
  • 7 が 6 枚
  • 8 が 5 枚
  • 9 が 6 枚

となっている。合計で 46 枚残っている。1 や 2 の残り枚数が少ないので

  • 「S の最後が 1、T の最後が 2」となる確率
  • 「S の最後が 5、T の最後が 6」となる確率

は等しくならないことに注意しよう。ちゃんと正確に確率を見積もってみよう。まず、ありうるすべての場合の数 (確率の分母) を求めてみる。

  • まず S に入るカードを選ぶ方法が 46 通り (一般には  9K-8 通り)
  • T に入るカードを残りのカードから選ぶ方法が 45 通り (一般には  9K-9 通り)

となるので、全体としては

 46 \times 45 = 2070 通り (一般には、 (9K-8)(9K-9) 通り)

となる。ここで、「組合せ」だと誤解して  \frac{46 \times 45}{2} = 1035 通りとしてしまわないように注意。今回、選んだ 2 枚のカードをそれぞれ S 側に渡すか T 側に渡すかを区別しないといけない!!!

さてサンプル 3 に戻ると、高橋君が勝てるパターンが

  • S の最後が 2
  • T の最後が 1

しかない。そして 2 が残り 1 枚しかないので前者が 1 通りで、1 が残り 2 枚なので後者が 2 通り。よって 2 通り。以上から確率は

 \frac{2}{2070} = \frac{1}{1035}

になる。

S の最後と T の最後が同じ場合と異なる場合

以上の考察を一般化して整理しよう!!!
まず、全体の場合の数は、先ほどのとおり  (9K-8)(9K-9) 通りとなる。

さて、 9K 枚のカードから 8 枚使って残ったカードのうち、 1, 2, \dots, 9 の枚数を  c_{1}, c_{2}, \dots, c_{9} 枚としよう。 S の最後が  i で、 T の最後が  j としたとき、もし  i j が異なるならば、単純に  c_{i} \times c_{j} 通りとなる。

問題は、 i = j となる場合。このときは

  •  c_{i} (= c_{j}) 枚のカードのうち、 S の最後に来るカードを選ぶ方法が  c_{i} 通り
  • その残りの  c_{i}-1 枚のカードのうち、 T の最後に来るカードを選ぶ方法が  c_{i}-1 通り

となるので、 c_{i} \times (c_{i} -1) 通りとなる。以上をまとめると、次のようになる。

  •  i \neq j のとき、確率は  \frac{c_{i}c_{j}}{(9K-8)(9K-9)}
  •  i = j のとき、確率は  \frac{c_{i}(c_{i}-1)}{(9K-8)(9K-9)}

コード

#include <bits/stdc++.h>
using namespace std;

long long score(const string &S) {
    vector<long long> val(10);
    for (int v = 1; v <= 9; ++v) val[v] = v;
    for (auto c: S) val[c-'0'] *= 10;
    long long sum = 0;
    for (int v = 1; v <= 9; ++v) sum += val[v];
    return sum;
}

int main() {
    long long K;
    string S, T;
    cin >> K >> S >> T;

    // 残りカード枚数
    vector<long long> rem(10, K);
    for (int i = 0; i < 4; ++i) rem[S[i]-'0']--, rem[T[i]-'0']--;

    double res = 0.0;
    for (int a = 1; a <= 9; ++a) {
        for (int b = 1; b <= 9; ++b) {
            S[4] = (char)('0' + a), T[4] = (char)('0' + b);

            // ひとまず分子だけ合算する
            if (score(S) > score(T)) {
                if (a != b) res += rem[a] * rem[b];
                else res += rem[a] * (rem[a] - 1);
            }
        }
    }
    res /= (K*9-8) * (K*9-9);
    cout << fixed << setprecision(10) << res << endl;
}

AtCoder ABC 193 C - Unexpressed (灰色, 300 点)

むずかしかった!!!
でも、約数列挙でありがちな「 \sqrt{N} まで試せば良い」という考え方がちゃんと理解できているかを問う良問だった!!

問題概要

整数  N が与えられる。

1 以上  N 以下の整数のうち、 2 以上の整数  a,b を用いて  a^{b} と表せないものはいくつあるか?

制約

  •  1 \le N \le 10^{10}

考えたこと

 1 から  N まで全部調べたのでは間に合わないから、なにか考えないといけない問題だね。

とりあえず  N が小さいと仮定して、全探索することから考えてみよう。そうしたら、 n = 1, 2, \dots, N それぞれに対して

 n a^{b} の形で表せるか」

という判定問題を解くことになる。これは次のようにして解ける。

  •  a = 2, 3, \dots と順にためして
  •  b = 2, 3, \dots と順に試して、 a^{b} = N となったら Yes、 a^{b} \gt N となったら探索終了

しかしこのままでは途方もない計算量となってしまう (ちゃんと解析すると  O(N^{2} \log N) になるはず)。

 a, b の方を動かして行く

ちょっと工夫してみよう。まず、各  n = 1, 2, \dots, N に対して、 a, b を探索していくのは無駄がある (たとえば  2^{2} = 4 などを何回も見て行くことになってしまう)。

そこで、次のように発想を切り替えてみよう。

  •  2^{2}, 2^{3}, 2^{4}, \dots を列挙する
  •  3^{2}, 3^{3}, 3^{4}, \dots を列挙する
  •  4^{2}, 4^{3}, 4^{4}, \dots を列挙する
  • ...

こうして列挙していって、その個数を  N から引けば良さそうだ。具体的にはどこまで探索すればよいだろうか??

それは、 \sqrt{N} までになる。なぜなら  \sqrt{N} より大きい整数は二乗すると  N より大きくなるからだ。よって、次のような探索をすれば良さそう。1 個目の for 文の判定条件を a * a <= N にしている。

for (long long a = 2; a * a <= N; ++a) {
    long long val = a * a;
    while (val <= N) {
        // val を列挙する

        // val に a をかける
        val *= a; 
    }
}

さて、一見これで良さそうだけど、まだ大きな罠がある。このままだと重複が発生するのだ。

  •  3^{4} = 81
  •  9^{2} = 81

という感じだ。この重複を取り除くことを考えよう!!! 

set を使う

重複を取り除くためには、set (C++ でも Python でもともに) を使うのが有効だ。こんなふうにする。

set<long long> ab;
for (long long a = 2; a * a <= N; ++a) {
    long long val = a * a;
    while (val <= N) {
        // val を列挙する
        ab.insert(val);

        // val に a をかける
        val *= a; 
    }
}

これによって、 a^{b} の形で表される整数をすべて列挙できる。その個数を  N から引けば OK。

注意点

重複を取り除こうとするときに、いくつか注意すべきことがある。

注意点 1:vector では MLE する

set が思いつかないと、

  • isab[v] ← 整数 v が  a^{b} の形で表せるなら True、そうでなければ False

という配列を作りたくなると思う。しかしサイズ  N の配列は、必要メモリ量があまりにも多いため、MLE してしまう。

注意点 2:「素数の冪乗だけ考えれば良い」は嘘

たとえば、 9^{2} = 81 3^{4} とも表せることなどから

「素数の冪乗だけ考えればよく、そうすれば重複を除去できるのではないか」

と考えた人も多かったと思う。しかしそれは嘘。

  •  6^{2} = 36

などは、素数の冪乗では表せないのだ。

コード

以上の考察をコードに落とせば OK。計算量は、Python の set を用いた場合は、

  •  a の探索範囲が  O(\sqrt{N}) まで
  • そのそれぞれについて  b の探索範囲が  O(\log N)

ということで  O(\sqrt{N} \log N) となって十分間に合う。C++ の set を用いた場合は  O(\sqrt{N} (\log N)^{2}) となる。

#include <bits/stdc++.h>
using namespace std;

int main() {
    long long N;
    cin >> N;
    set<long long> ab;
    for (long long a = 2; a * a <= N; ++a) {
        long long val = a * a;
        while (val <= N) {
            ab.insert(val);
            val *= a;
        }
    }
    cout << N - ab.size() << endl;
}

AtCoder ABC 077 C - Snuke Festival (ARC 084 C) (緑色, 300 点)

lower_bound の練習に!!! あと、「3 つのものを考えるときは、真ん中を固定して考える」という考え方の典型。

問題概要

3 つの数列 (長さ  N)

  •  a_{0}, a_{1}, \dots, a_{N-1}
  •  b_{0}, b_{1}, \dots, b_{N-1}
  •  c_{0}, c_{1}, \dots, c_{N-1}

が与えられる。各数列から要素  (a_{i}, b_{j}, c_{k}) を選ぶ方法のうち、

 a_{i} \lt b_{j} \lt c_{k}

を満たすものの個数を求めよ。

制約

  •  1 \le N \le 10^{5}

考えたこと

単純にすべてのペア  (a_{i}, b_{j}, c_{k}) を考える方法では  O(N^{3}) の計算量となりますので、工夫が必要です。

まずは考えやすくなるように、各数列は小さい順にソートされているものとしましょう (ソートは  O(N \log N) でできます)。さて、3 つ組について考えるときは、真ん中を固定して考えるのが定石です。つまり、 b_{j} を固定して考えてみましょう。このとき、

  •  a_{0}, a_{1}, \dots, a_{N-1} のうち、 b_{j} 未満の個数を数える ( A_{j} とする)
  •  c_{0}, c_{1}, \dots, c_{N-1} のうち、 b_{j} 以上の個数を数える ( C_{j} とする)

というようにします (下図参照)。 a_{i} \lt b_{j} \lt c_{k} を満たす組の個数は  A_{j} \times C_{j} となります。これを  j = 0, 1, \dots, N-1 について足すことで答えが求められます。

 A_{j} C_{j} を求める

これらはいずれも二分探索 (特に lower_bound) を使って求められます!

まず  a_{0}, a_{1}, \dots, a_{N-1} のうち、 b_{j} 未満の個数を数えてみましょう。lower_bound を使うと、次のようにして  a_{k} \ge b_{j} を満たす最小の  k が求められます。

int k = lower_bound(a.begin(), a.end(), b[j]) - a.begin();

そしてこのとき、 a_{i} \lt b_{j} を満たす  i の個数は  k になります。こうして  A_{j} = k であることがわかりました。

 C_{j} についても同様に求められます。具体的には、 c_{i} \le b_{j} (不等号に等号が付くことに注意) を満たす  i の個数を求めて、それを  N から引けばよいでしょう。そのような  i の個数は、lower_bound() の代わりに upper_bound() を用いることで求められます。

計算量・コード

 b_{j} に対して、 A_{j} C_{j} O(\log N) で求められます。よって全体の計算量は  O(N \log N) となることがわかりました。

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

int main() {
    // 入力
    int N;
    cin >> N;
    vector<long long> a(N), b(N), c(N);
    for (int i = 0; i < N; ++i) cin >> a[i];
    for (int i = 0; i < N; ++i) cin >> b[i];
    for (int i = 0; i < N; ++i) cin >> c[i];

    // ソートする
    sort(a.begin(), a.end());
    sort(b.begin(), b.end());
    sort(c.begin(), c.end());

    // b[j] を固定して考える
    long long res = 0;
    for (int j = 0; j < N; ++j) {
        long long Aj = lower_bound(a.begin(), a.end(), b[j]) - a.begin();
        long long Cj = N - (upper_bound(c.begin(), c.end(), b[j]) - c.begin());
        res += Aj * Cj;
    }
    cout << res << endl;
}

パ研合宿2020 第1日「SpeedRun」 N - 背の順

面白かった!セグ木を使ったけど、区間を複数個にする必要がないことから、しゃくとり法で線形でできるね!

問題概要

 1, 2, \dots, N の順列  A_{1}, A_{2}, \dots, A_{N} が与えられる。以下の操作を繰り返すことで、単調増加となるようにしたい。

  • 区間  \lbrack l, r \rbrack の要素をすべて削除する (コストは  A_{l} + A_{r})

目的を達成するための最小コストを求めよ。

制約

  •  1 \le N \le 2 \times 10^{5}

解法 (1):僕のやった解法

まず、操作する区間が入れ子になったり交差したりする意味はないので、操作は「残す要素を固定すると、その間隔の区間を削除していく」というものになる。

これを踏まえて次の DP をする。

  • dp[v] ← 順列を左から見ていって、値  v が来たときに、値  v を最後の要素としたときの、最小コスト (ただし値 v の右隣の値も合算しておく)

この DP 配列をセグ木 (RMQ) に載せることで、LIS を求める DP と同じ要領で高速に更新していける。ただし DP 更新において、値  v の左隣の要素も選ぶ場合については、別途遷移する。

計算量は  O(N \log N)

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

template<class Monoid> struct RMQ {
    Monoid INF;
    int SIZE_R;
    vector<pair<Monoid,int> > dat;
    
    RMQ() {}
    RMQ(int n, const Monoid &inf): INF(inf) { 
        init(n, inf);
    }
    void init(int n, const Monoid &inf) {
        INF = inf;
        SIZE_R = 1;
        while (SIZE_R < n) SIZE_R *= 2;
        dat.assign(SIZE_R * 2, pair<Monoid,int>(INF, -1));
    }
    
    /* set, a is 0-indexed */
    void set(int a, const Monoid &v) { dat[a + SIZE_R] = make_pair(v, a); }
    void build() {
        for (int k = SIZE_R - 1; k > 0; --k) {
            dat[k] = min(dat[k*2], dat[k*2+1]);
        }
    }
    
    /* update, a is 0-indexed */
    void update(int a, const Monoid &v) {
        int k = a + SIZE_R;
        dat[k] = make_pair(v, a);
        while (k >>= 1) dat[k] = min(dat[k*2], dat[k*2+1]);
    }
    
    /* get {min-value, min-index}, a and b are 0-indexed */
    pair<Monoid,int> get(int a, int b) {
        pair<Monoid,int> vleft = make_pair(INF, -1), vright = make_pair(INF, -1);
        for (int left = a + SIZE_R, right = b + SIZE_R; left < right; left >>= 1, right >>= 1) {
            if (left & 1) vleft = min(vleft, dat[left++]);
            if (right & 1) vright = min(dat[--right], vright);
        }
        return min(vleft, vright);
    }
    inline Monoid operator [] (int a) { return dat[a + SIZE_R].first; }
    
    /* debug */
    void print() {
        for (int i = 0; i < SIZE_R; ++i) {
            Monoid val = (*this)[i];
            if (val < INF) cout << val;
            else cout << "INF";
            if (i != SIZE_R-1) cout << ",";
        }
        cout << endl;
    }
};


const long long INF = 1LL<<60;
int main() {
    int N;
    cin >> N;
    deque<long long> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];
    A.push_front(0), A.push_back(N+1), A.push_back(0);
    RMQ<long long> dp(N+10, INF);
    dp.update(0, A[1]);
    for (int i = 1; i <= N+1; ++i) {
        long long val = dp.get(0, A[i]).first + A[i-1] + A[i+1];
        if (i-1 >= 0 && A[i-1] < A[i]) {
            chmin(val, dp.get(A[i-1], A[i-1]+1).first - A[i] + A[i+1]);
        }
        dp.update(A[i], val);
    }
    cout << dp.get(N+1, N+2).first << endl;
}

解法 (2)

よくよく考えると、操作する区間は 1 個だけでよい。なぜなら、いくつかの区間について操作するくらいなら、

  • 左端の区間の左端
  • 右端の区間の右端

を選んでまとめて削除してしまった方がよいからだ。そしてそのような条件を満たす区間は単調性 (区間 A が条件を満たすならば、A を包含する区間 B も条件を満たす) を満たすので、しゃくとり法で解ける。この場合計算量は  O(N) となる。

JOI 春合宿 2007 day3-2 Route (難易度 7)

DIjkstra をするときに、直前の頂点ももつ系

問題概要

頂点数  N、辺数  M の重み付き無向グラフが与えられる。頂点  i の座標は  (X_{i}, Y_{i}) となっている。

頂点 1 から頂点 2 へと至る経路のうち、鋭角に曲がる箇所がないようなものを考える (頂点  v の前の頂点を  u v の次の頂点を  w としたとき、線分  uv と線分  uw のなす角が 90 度未満)。

そのような経路の最短路長を求めよ。

制約

  •  2 \le N \le 100
  •  0 \le M \le \frac{N(N-1)}{2}
  • 座標値の絶対値は  10000 以下

考えたこと

基本的には Dijkstra 法を使うことで最短路が求められる。ただし、頂点  v から次の頂点  nv へと行けるかどうかを判定するために、頂点  v の前の頂点がどうだったかの情報が必要になる。よって、

  • dp[pv][v] ← 始点 (頂点 1) から頂点  v へと至る経路のうち、頂点  v の直前の頂点が  pv であるようなものの長さの最小値

として、Dijkstra 法で扱っていく。これは頂点集合を拡張したグラフ上の Dijkstra 法とみなすことができる。あとは、以下の判定関数が作れれば OK。

// 頂点 pv から頂点 v へ行き、そこから頂点 nv へ行けるか
auto isvalid = [&](int pv, int v, int nv) -> bool {
    ang = (線分 v-pv と線分 v-nv のなす角)
    if (dif < π/2 - EPS) return false;
    else return true;
};

コード

ここでは幾何ライブラリで殴ることにした。計算量を解析すると

  • 拡張したグラフの頂点数: O(M)
  • 拡張したグラフの辺数: O(NM)

となるので、priority_queue を使って Dijkstra をすると  O(NM \log N) となる。

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

// 幾何
using DD = double;
const DD EPS = 1e-10;        // to be set appropriately
const DD PI = acosl(-1.0);
DD torad(int deg) {return (DD)(deg) * PI / 180;}
DD todeg(DD ang) {return ang * 180 / PI;}

struct Point {
    DD x, y;
    Point(DD x = 0.0, DD y = 0.0) : x(x), y(y) {}
    friend ostream& operator << (ostream &s, const Point &p) {return s << '(' << p.x << ", " << p.y << ')';}
};
inline Point operator + (const Point &p, const Point &q) {return Point(p.x + q.x, p.y + q.y);}
inline Point operator - (const Point &p, const Point &q) {return Point(p.x - q.x, p.y - q.y);}
inline Point operator * (const Point &p, DD a) {return Point(p.x * a, p.y * a);}
inline Point operator * (DD a, const Point &p) {return Point(a * p.x, a * p.y);}
inline Point operator * (const Point &p, const Point &q) {return Point(p.x * q.x - p.y * q.y, p.x * q.y + p.y * q.x);}
inline Point operator / (const Point &p, DD a) {return Point(p.x / a, p.y / a);}
inline Point conj(const Point &p) {return Point(p.x, -p.y);}
inline Point rot(const Point &p, DD ang) {return Point(cos(ang) * p.x - sin(ang) * p.y, sin(ang) * p.x + cos(ang) * p.y);}
inline Point rot90(const Point &p) {return Point(-p.y, p.x);}
inline DD cross(const Point &p, const Point &q) {return p.x * q.y - p.y * q.x;}
inline DD dot(const Point &p, const Point &q) {return p.x * q.x + p.y * q.y;}
inline DD norm(const Point &p) {return dot(p, p);}
inline DD abs(const Point &p) {return sqrt(dot(p, p));}
inline DD amp(const Point &p) {DD res = atan2(p.y, p.x); if (res < 0) res += PI*2; return res;}

const long long INF = 1LL<<60;
int main() {
    int N, M;
    cin >> N >> M;
    vector<Point> vp(N);
    for (int i = 0; i < N; ++i) cin >> vp[i].x >> vp[i].y;

    auto isvalid = [&](int pv, int v, int nv) -> bool {
        DD pang = amp(vp[pv] - vp[v]), nang = amp(vp[nv] - vp[v]);
        DD dif = abs(pang - nang);
        if (dif < PI/2 - EPS) return false;
        else return true;
    };

    using pint = pair<int, int>; // prev node, current node
    using Edge = pair<int, long long>;
    using Graph = vector<vector<Edge>>;
    Graph G(N);
    for (int i = 0; i < M; ++i) {
        int u, v, w;
        cin >> u >> v >> w;
        --u, --v;
        G[u].emplace_back(v, w), G[v].emplace_back(u, w);
    } 
    int s = 0;
    vector<vector<long long>> dp(N, vector<long long>(N, INF));
    priority_queue<pair<long long, pint>,
                   vector<pair<long long, pint>>,
                   greater<pair<long long, pint>>> que;
    for (auto e: G[s]) {
        dp[s][e.first] = e.second;
        que.push(make_pair(dp[s][e.first], pint(s, e.first)));
    }
    while (!que.empty()) {
        auto tmp = que.top();
        que.pop();
        long long dis = tmp.first;
        int pv = tmp.second.first, v = tmp.second.second;
        if (dis > dp[pv][v]) continue;
        for (auto e: G[v]) {
            int nv = e.first;
            if (!isvalid(pv, v, nv)) continue;
            if (chmin(dp[v][nv], dp[pv][v] + e.second)) {
                que.push(make_pair(dp[v][nv], pint(v, nv)));
            }
        }
    }
    long long res = INF;
    for (int v = 0; v < N; ++v) chmin(res, dp[v][1]);
    cout << (res < INF ? res : -1) << endl;
}