コーナーケースがえぐい!!
僕は最初、(1, -1), (-1, 3) で Yes を返してしまっていた。
問題概要
個の区間 があって、
- 両端の座標は のいずれか
- 両端の座標をかき集めたとき、重複がない
- 区間 と区間 がもし重なっているならば、区間 の長さと、区間 の長さが等しい
という条件を満たしていた。いま、これらの区間の情報の一部が欠けた状態で入力が与えられる (欠損値は -1 で表す)。また、一部の情報が書き換わっている可能性もある。
上の条件をすべて満たすように全区間を復元することが可能かどうかを判定せよ (一意でなくともよい)。
制約
考えたこと
「区間の重なりがあるならば長さが等しい」という訳のわからない条件があるので、こういうのはわかりやすく言い換えて行くことを考える。
まず「2 つの区間の重なり方」について考える。大きく分けて下の 2 つのパターンがあるが、左側はダメで、からなず右側のような配置関係にしなければならない。
さらに、複数の区間が重なり合っている部分について考える。それらを上手に並び替えると、からなず下図の右側のように、
「両端をピックアップすると連続する自然数となる」
という風になっていることがわかる。左側のように、空き地があるのはダメ。
区間 DP へ
よって条件を満たすような区間配置は、次のように言える。
- 重なる区間を同じグループにして、グループ分けしたとき
- それぞれのグループは、 を等差数列として、 という 個の区間からなる
こうなったら次のような DP を考えるのは自然。
- dp[ v ] := 座標区間 [0, v) の範囲について、上記のようにグループ分けできるか
そして dp[ 2N ] が最終的な答えとなる。遷移は次のようにする。
dp[j] |= dp[i] (座標区間 (i, j) でグループを形成できるとき)
計算量を評価する
- 最初にすべての座標区間に対して、グループを形成できるかどうかを判定する
- 区間 DP をする
によって、全体として となる。なお、すべての座標区間に対してグループを形成できるかどうかを判定する部分は、コーナーケースがとても多い。僕が引っかかったのは、こういうケース
(1, -1), (-1, 3)
このケースでは座標区間 [1, 4) でグループを形成することはできないが、できるものとしてしまっていた。
コード
#include <bits/stdc++.h> using namespace std; int N; vector<int> LtoR, RtoL; bool isValid(int l, int r) { if (l >= r) return false; if ((r - l) % 2 == 1) return false; int m = (l + r) / 2; int d = (r - l) / 2; for (int i = m; i < r; ++i) if (LtoR[i] != -2) return false; for (int i = l; i < m; ++i) if (RtoL[i] != -2) return false; for (int i = l; i < m; ++i) { if (LtoR[i] >= 0 && LtoR[i] != i+d) return false; if (LtoR[i] == -1 && RtoL[i+d] != -2) return false; } for (int i = m; i < r; ++i) { if (RtoL[i] >= 0 && RtoL[i] != i-d) return false; if (RtoL[i] == -1 && LtoR[i-d] != -2) return false; } return true; } bool solve() { set<int> se; cin >> N; LtoR.assign(N*2, -2), RtoL.assign(N*2, -2); bool ok = true; for (int i = 0; i < N; ++i) { int a, b; cin >> a >> b; if (a != -1) --a; if (b != -1) --b; if (a != -1) { LtoR[a] = b; if (se.count(a)) ok = false; se.insert(a); } if (b != -1) { RtoL[b] = a; if (se.count(b)) ok = false; se.insert(b); } if (a != -1 && b != -1 && a >= b) ok = false; } if (!ok) return false; // pre-calc vector<vector<bool>> valid(N*2+1, vector<bool>(N*2+1, false)); for (int l = 0; l < N*2; ++l) { for (int r = l+1; r <= N*2; ++r) { valid[l][r] = isValid(l, r); } } // dp vector<bool> dp(N*2+1, false); dp[0] = true; for (int i = 0; i < N*2; ++i) { if (!dp[i]) continue; for (int j = i+1; j <= N*2; ++j) { if (valid[i][j]) dp[j] = true; } } return dp[N*2]; } int main() { if (solve()) cout << "Yes" << endl; else cout << "No" << endl; }