けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

最短経路の個数も一緒に数え上げる最短経路アルゴリズム

ARC 090 E - Avoiding Collision で話題になったこともあり、簡単にメモします。

最短経路を求める DP 的処理をするとき、DAG上のDP だろうと、BFS だろうと、Dijkstra だろうと、以下のような緩和処理をやっています

// Edge e の緩和

int dp[MAX_V];  // dp[v] := 始点 s から頂点 v への最短経路長

if (dp[e.to] < dp[e.from] + e.cost) {
 dp[e.to] = dp[e.from] + e.cost;
}

これをちょっと変えるだけで、最短経路の本数も一緒に数え上げられます。

// Edge e の緩和

int dp[MAX_V];  // dp[v] := 始点 s から頂点 v への最短経路長
int num[MAX_V];  // num[v] := 始点 s から頂点 v への最短経路数

if (dp[e.to] < dp[e.from] + e.cost) {
 dp[e.to] = dp[e.from] + e.cost;
 num[e.to] = num[e.from];
}
else if (dp[e.to] == dp[e.from] + e.cost) {
 num[e.to] += num[e.from];
 num[e.to] %= MOD;
}

ARC 090 E - Avoiding Collision


N 頂点 M 辺からなる重みつき無向グラフが与えられる。 高橋君が頂点 S を、青木君が頂点 T を出発して、互いの出発地への最短経路を歩む。 二人の最短路の選び方の組であって、移動の途中で二人が (辺または頂点上で) 出会うことのないようなものの個数を 1000000007 で割ったあまりで求めよ。

・2 <= N <= 105
・1 <= M <= 2*105


S から T への最短経路数 (= T から S への最短経路数) の二乗から、条件を満たさないものを引く。条件を満たさないものは

  • 2人が頂点 v 上でぴったり出会う
  • 2人が枝 (u, v) 上でぴったり出会う (u, v 上を含まない)

である。

#include <iostream>
#include <vector>
#include <queue>
using namespace std;

const long long MOD = 1000000007;
inline long long mod(long long a, long long m) { return (a % m + m) % m; }

typedef pair<long long, int> Edge;

const int MAX = 110000;
const long long INF = 1LL<<59;
int N, M, S, T;
vector<Edge> G[MAX];

long long ds[MAX], ns[MAX], dt[MAX], nt[MAX];

int main() {
    while (cin >> N >> M >> S >> T) {
        --S, --T;
        for (int i = 0; i < MAX; ++i) G[i].clear();
        for (int i = 0; i < M; ++i) {
            int u, v, d;
            cin >> u >> v >> d;
            --u, --v;
            G[u].push_back(Edge(d, v));
            G[v].push_back(Edge(d, u));
        }
        for (int i = 0; i < MAX; ++i) {
            ds[i] = dt[i] = INF;
            ns[i] = nt[i] = 0;
        }
        ds[S] = 0;
        ns[S] = 1;
        priority_queue<Edge, vector<Edge>, greater<Edge> > que;
        que.push(Edge(0, S));
        while (!que.empty()) {
            long long curd = que.top().first;
            int cur = que.top().second;
            que.pop();
            if (ds[cur] < curd) continue;
            for (auto e : G[cur]) {
                if (ds[e.second] > ds[cur] + e.first) {
                    ds[e.second] = ds[cur] + e.first;
                    ns[e.second] = ns[cur];
                    que.push(Edge(ds[e.second], e.second));
                }
                else if (ds[e.second] == ds[cur] + e.first) {
                    ns[e.second] += ns[cur];
                    ns[e.second] %= MOD;
                }
            }
        }
        dt[T] = 0;
        nt[T] = 1;
        que.push(Edge(0, T));
        while (!que.empty()) {
            long long curd = que.top().first;
            int cur = que.top().second;
            que.pop();
            if (dt[cur] < curd) continue;
            for (auto e : G[cur]) {
                if (dt[e.second] > dt[cur] + e.first) {
                    dt[e.second] = dt[cur] + e.first;
                    nt[e.second] = nt[cur];
                    que.push(Edge(dt[e.second], e.second));
                }
                else if (dt[e.second] == dt[cur] + e.first) {
                    nt[e.second] += nt[cur];
                    nt[e.second] %= MOD;
                }
            }
        }

        long long D = ds[T];
        long long res = (ns[T] * nt[S]) % MOD;
        for (int v = 0; v < N; ++v) {
            // v 上を引く
            if (ds[v] == dt[v] && ds[v] + dt[v] == D) {
                long long sub = (ns[v] * nt[v]) % MOD;
                sub = (sub * sub) % MOD;
                res = mod(res - sub, MOD);
            }
            // e = (from -> to) 上を引く
            for (auto e : G[v]) {
                int from = v;
                int to = e.second;
                long long dis = e.first;
                if (ds[from] + dis + dt[to] != D) continue;
                if (ds[from] == dt[from] || ds[to] == dt[to]) continue;
                if (ds[from] < dt[from] && ds[to] > dt[to]) {
                    long long sub = (ns[from] * nt[to]) % MOD;
                    sub = (sub * sub) % MOD;
                    res = mod(res - sub, MOD);
                }
            }
        }
        
        cout << res << endl;
    }
}

さらに一般化

DP の緩和が成立しうる構造をより一般的に抽象的にとらえる議論が kirika さんによってなされています。その立場で「最短経路の個数も一緒に数える最短路DP」を眺めたものとして以下の記事がとても面白いです。

一般的なダイクストラ法 - kirika_compのブログ