けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

JOI 春合宿 2007 day3-2 Route (難易度 7)

DIjkstra をするときに、直前の頂点ももつ系

問題概要

頂点数  N、辺数  M の重み付き無向グラフが与えられる。頂点  i の座標は  (X_{i}, Y_{i}) となっている。

頂点 1 から頂点 2 へと至る経路のうち、鋭角に曲がる箇所がないようなものを考える (頂点  v の前の頂点を  u v の次の頂点を  w としたとき、線分  uv と線分  uw のなす角が 90 度未満)。

そのような経路の最短路長を求めよ。

制約

  •  2 \le N \le 100
  •  0 \le M \le \frac{N(N-1)}{2}
  • 座標値の絶対値は  10000 以下

考えたこと

基本的には Dijkstra 法を使うことで最短路が求められる。ただし、頂点  v から次の頂点  nv へと行けるかどうかを判定するために、頂点  v の前の頂点がどうだったかの情報が必要になる。よって、

  • dp[pv][v] ← 始点 (頂点 1) から頂点  v へと至る経路のうち、頂点  v の直前の頂点が  pv であるようなものの長さの最小値

として、Dijkstra 法で扱っていく。これは頂点集合を拡張したグラフ上の Dijkstra 法とみなすことができる。あとは、以下の判定関数が作れれば OK。

// 頂点 pv から頂点 v へ行き、そこから頂点 nv へ行けるか
auto isvalid = [&](int pv, int v, int nv) -> bool {
    ang = (線分 v-pv と線分 v-nv のなす角)
    if (dif < π/2 - EPS) return false;
    else return true;
};

コード

ここでは幾何ライブラリで殴ることにした。計算量を解析すると

  • 拡張したグラフの頂点数: O(M)
  • 拡張したグラフの辺数: O(NM)

となるので、priority_queue を使って Dijkstra をすると  O(NM \log N) となる。

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

// 幾何
using DD = double;
const DD EPS = 1e-10;        // to be set appropriately
const DD PI = acosl(-1.0);
DD torad(int deg) {return (DD)(deg) * PI / 180;}
DD todeg(DD ang) {return ang * 180 / PI;}

struct Point {
    DD x, y;
    Point(DD x = 0.0, DD y = 0.0) : x(x), y(y) {}
    friend ostream& operator << (ostream &s, const Point &p) {return s << '(' << p.x << ", " << p.y << ')';}
};
inline Point operator + (const Point &p, const Point &q) {return Point(p.x + q.x, p.y + q.y);}
inline Point operator - (const Point &p, const Point &q) {return Point(p.x - q.x, p.y - q.y);}
inline Point operator * (const Point &p, DD a) {return Point(p.x * a, p.y * a);}
inline Point operator * (DD a, const Point &p) {return Point(a * p.x, a * p.y);}
inline Point operator * (const Point &p, const Point &q) {return Point(p.x * q.x - p.y * q.y, p.x * q.y + p.y * q.x);}
inline Point operator / (const Point &p, DD a) {return Point(p.x / a, p.y / a);}
inline Point conj(const Point &p) {return Point(p.x, -p.y);}
inline Point rot(const Point &p, DD ang) {return Point(cos(ang) * p.x - sin(ang) * p.y, sin(ang) * p.x + cos(ang) * p.y);}
inline Point rot90(const Point &p) {return Point(-p.y, p.x);}
inline DD cross(const Point &p, const Point &q) {return p.x * q.y - p.y * q.x;}
inline DD dot(const Point &p, const Point &q) {return p.x * q.x + p.y * q.y;}
inline DD norm(const Point &p) {return dot(p, p);}
inline DD abs(const Point &p) {return sqrt(dot(p, p));}
inline DD amp(const Point &p) {DD res = atan2(p.y, p.x); if (res < 0) res += PI*2; return res;}

const long long INF = 1LL<<60;
int main() {
    int N, M;
    cin >> N >> M;
    vector<Point> vp(N);
    for (int i = 0; i < N; ++i) cin >> vp[i].x >> vp[i].y;

    auto isvalid = [&](int pv, int v, int nv) -> bool {
        DD pang = amp(vp[pv] - vp[v]), nang = amp(vp[nv] - vp[v]);
        DD dif = abs(pang - nang);
        if (dif < PI/2 - EPS) return false;
        else return true;
    };

    using pint = pair<int, int>; // prev node, current node
    using Edge = pair<int, long long>;
    using Graph = vector<vector<Edge>>;
    Graph G(N);
    for (int i = 0; i < M; ++i) {
        int u, v, w;
        cin >> u >> v >> w;
        --u, --v;
        G[u].emplace_back(v, w), G[v].emplace_back(u, w);
    } 
    int s = 0;
    vector<vector<long long>> dp(N, vector<long long>(N, INF));
    priority_queue<pair<long long, pint>,
                   vector<pair<long long, pint>>,
                   greater<pair<long long, pint>>> que;
    for (auto e: G[s]) {
        dp[s][e.first] = e.second;
        que.push(make_pair(dp[s][e.first], pint(s, e.first)));
    }
    while (!que.empty()) {
        auto tmp = que.top();
        que.pop();
        long long dis = tmp.first;
        int pv = tmp.second.first, v = tmp.second.second;
        if (dis > dp[pv][v]) continue;
        for (auto e: G[v]) {
            int nv = e.first;
            if (!isvalid(pv, v, nv)) continue;
            if (chmin(dp[v][nv], dp[pv][v] + e.second)) {
                que.push(make_pair(dp[v][nv], pint(v, nv)));
            }
        }
    }
    long long res = INF;
    for (int v = 0; v < N; ++v) chmin(res, dp[v][1]);
    cout << (res < INF ? res : -1) << endl;
}