summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHaidong Ji2021-07-05 15:30:18 -0500
committerHaidong Ji2021-07-05 15:30:18 -0500
commit578f8f5874e66a35660eb0759ef7d90a27fbcffe (patch)
treedcc51682d315c94c8db3c0346ea3ed0799034fc4
parent4fd3e8a8f86838720fe43632ce9b06d3d2edaa23 (diff)
Dijkstra shortest path done, using heapq.
Pretty simple after I first implemented in Java.
-rw-r--r--sources/dijkstra.py45
-rw-r--r--tests/test_reachability.py6
2 files changed, 48 insertions, 3 deletions
diff --git a/sources/dijkstra.py b/sources/dijkstra.py
new file mode 100644
index 0000000..e5160d4
--- /dev/null
+++ b/sources/dijkstra.py
@@ -0,0 +1,45 @@
+# Uses python3
+
+import sys
+import heapq
+
+
+def distance(adj, cost, s, t):
+ if s == t:
+ return 0
+ dist = [float('inf')] * len(adj)
+ visited = [False] * len(adj)
+ dist[s] = 0
+ q = []
+ heapq.heappush(q, (dist[s], s))
+
+ while len(q) != 0:
+ u = heapq.heappop(q)[1]
+ if visited[u]:
+ continue
+ visited[u] = True
+ if len(adj[u]) > 0:
+ for i in range(len(adj[u])):
+ if dist[adj[u][i]] > dist[u] + cost[u][i]:
+ dist[adj[u][i]] = dist[u] + cost[u][i]
+ heapq.heappush(q, (dist[adj[u][i]], adj[u][i]))
+
+ if dist[t] == float('inf'):
+ return -1
+ return dist[t]
+
+
+if __name__ == '__main__':
+ input = sys.stdin.read()
+ data = list(map(int, input.split()))
+ n, m = data[0:2]
+ data = data[2:]
+ edges = list(zip(zip(data[0:(3 * m):3], data[1:(3 * m):3]), data[2:(3 * m):3]))
+ data = data[3 * m:]
+ adj = [[] for _ in range(n)]
+ cost = [[] for _ in range(n)]
+ for ((a, b), w) in edges:
+ adj[a - 1].append(b - 1)
+ cost[a - 1].append(w)
+ s, t = data[0] - 1, data[1] - 1
+ print(distance(adj, cost, s, t))
diff --git a/tests/test_reachability.py b/tests/test_reachability.py
index 318309b..8908c30 100644
--- a/tests/test_reachability.py
+++ b/tests/test_reachability.py
@@ -5,10 +5,10 @@ from sources.reachability import reach
class TestReachability(unittest.TestCase):
def testName(self):
- adj = [[1, 3],[], [1],[2]]
+ adj = [[1, 3], [], [1], [2]]
self.assertEqual(1, reach(adj, 0, 3))
if __name__ == "__main__":
- #import sys;sys.argv = ['', 'Test.testName']
- unittest.main() \ No newline at end of file
+ # import sys;sys.argv = ['', 'Test.testName']
+ unittest.main()