diff options
author | Haidong Ji | 2021-07-05 15:30:18 -0500 |
---|---|---|
committer | Haidong Ji | 2021-07-05 15:30:18 -0500 |
commit | 578f8f5874e66a35660eb0759ef7d90a27fbcffe (patch) | |
tree | dcc51682d315c94c8db3c0346ea3ed0799034fc4 | |
parent | 4fd3e8a8f86838720fe43632ce9b06d3d2edaa23 (diff) |
Dijkstra shortest path done, using heapq.
Pretty simple after I first implemented in Java.
-rw-r--r-- | sources/dijkstra.py | 45 | ||||
-rw-r--r-- | tests/test_reachability.py | 6 |
2 files changed, 48 insertions, 3 deletions
diff --git a/sources/dijkstra.py b/sources/dijkstra.py new file mode 100644 index 0000000..e5160d4 --- /dev/null +++ b/sources/dijkstra.py @@ -0,0 +1,45 @@ +# Uses python3 + +import sys +import heapq + + +def distance(adj, cost, s, t): + if s == t: + return 0 + dist = [float('inf')] * len(adj) + visited = [False] * len(adj) + dist[s] = 0 + q = [] + heapq.heappush(q, (dist[s], s)) + + while len(q) != 0: + u = heapq.heappop(q)[1] + if visited[u]: + continue + visited[u] = True + if len(adj[u]) > 0: + for i in range(len(adj[u])): + if dist[adj[u][i]] > dist[u] + cost[u][i]: + dist[adj[u][i]] = dist[u] + cost[u][i] + heapq.heappush(q, (dist[adj[u][i]], adj[u][i])) + + if dist[t] == float('inf'): + return -1 + return dist[t] + + +if __name__ == '__main__': + input = sys.stdin.read() + data = list(map(int, input.split())) + n, m = data[0:2] + data = data[2:] + edges = list(zip(zip(data[0:(3 * m):3], data[1:(3 * m):3]), data[2:(3 * m):3])) + data = data[3 * m:] + adj = [[] for _ in range(n)] + cost = [[] for _ in range(n)] + for ((a, b), w) in edges: + adj[a - 1].append(b - 1) + cost[a - 1].append(w) + s, t = data[0] - 1, data[1] - 1 + print(distance(adj, cost, s, t)) diff --git a/tests/test_reachability.py b/tests/test_reachability.py index 318309b..8908c30 100644 --- a/tests/test_reachability.py +++ b/tests/test_reachability.py @@ -5,10 +5,10 @@ from sources.reachability import reach class TestReachability(unittest.TestCase): def testName(self): - adj = [[1, 3],[], [1],[2]] + adj = [[1, 3], [], [1], [2]] self.assertEqual(1, reach(adj, 0, 3)) if __name__ == "__main__": - #import sys;sys.argv = ['', 'Test.testName'] - unittest.main()
\ No newline at end of file + # import sys;sys.argv = ['', 'Test.testName'] + unittest.main() |