diff options
author | Haidong Ji | 2019-03-04 20:51:19 -0600 |
---|---|---|
committer | Haidong Ji | 2019-03-04 20:51:19 -0600 |
commit | 9b38f4a68773bb08fe923136c9faeade6a9794e9 (patch) | |
tree | fba0be238082a91bf0275852038dde81779a1b30 | |
parent | 421f2e783af773712a5e7a799f7cbaca69ab7a21 (diff) |
Merging tables done.
Not too bad since I worked it out in Java. A bit surprised that
Python class is used in starter file, but I made it testable
and wrote test cases. All is well.
-rw-r--r-- | .idea/encodings.xml | 4 | ||||
-rw-r--r-- | sources/merging_tables.py | 56 | ||||
-rw-r--r-- | tests/merging_tablesTest.py | 34 |
3 files changed, 94 insertions, 0 deletions
diff --git a/.idea/encodings.xml b/.idea/encodings.xml new file mode 100644 index 0000000..15a15b2 --- /dev/null +++ b/.idea/encodings.xml @@ -0,0 +1,4 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="Encoding" addBOMForNewFiles="with NO BOM" /> +</project>
\ No newline at end of file diff --git a/sources/merging_tables.py b/sources/merging_tables.py new file mode 100644 index 0000000..985316e --- /dev/null +++ b/sources/merging_tables.py @@ -0,0 +1,56 @@ +# python3 + + +class Database: + def __init__(self, row_counts): + self.row_counts = row_counts + self.max_row_count = max(row_counts) + n_tables = len(row_counts) + self.ranks = [1] * n_tables + self.parents = list(range(n_tables)) + self.set_id = list(range(n_tables)) + + def find(self, i): + if i != self.parents[i]: + self.parents[i] = self.find(self.parents[i]) + return self.parents[i] + + def merge(self, src, dst): + real_dst = self.find(dst) + real_src = self.find(src) + if real_dst == real_src: + return + self.max_row_count = max(self.max_row_count, + self.row_counts[self.set_id[real_dst]] + self.row_counts[self.set_id[real_src]]) + self.row_counts[real_dst] = self.row_counts[self.set_id[real_dst]] + self.row_counts[self.set_id[real_src]] + self.row_counts[real_src] = 0 + self.set_id[real_src] = real_dst + self.set_id[real_dst] = real_dst + self.union(src, dst) + + def union(self, src, dst): + destination_id = self.find(dst) + source_id = self.find(src) + if destination_id == source_id: + return + if self.ranks[destination_id] > self.ranks[source_id]: + self.parents[source_id] = destination_id + else: + self.parents[destination_id] = source_id + if self.ranks[destination_id] == self.ranks[source_id]: + self.ranks[source_id] = self.ranks[source_id] + 1 + + +def main(): + n_tables, n_queries = map(int, input().split()) + counts = list(map(int, input().split())) + assert len(counts) == n_tables + db = Database(counts) + for i in range(n_queries): + dst, src = map(int, input().split()) + db.merge(dst - 1, src - 1) + print(db.max_row_count) + + +if __name__ == "__main__": + main() diff --git a/tests/merging_tablesTest.py b/tests/merging_tablesTest.py new file mode 100644 index 0000000..ed83135 --- /dev/null +++ b/tests/merging_tablesTest.py @@ -0,0 +1,34 @@ +import unittest +from sources import merging_tables + + +class MyTestCase(unittest.TestCase): + def test1(self): + counts = [1, 1, 1, 1, 1] + db = merging_tables.Database(counts) + db.merge(4, 2) + self.assertEqual(2, db.max_row_count) + db.merge(3, 1) + self.assertEqual(2, db.max_row_count) + db.merge(3, 0) + self.assertEqual(3, db.max_row_count) + db.merge(3, 4) + self.assertEqual(5, db.max_row_count) + db.merge(2, 4) + self.assertEqual(5, db.max_row_count) + + def test2(self): + counts = [10, 0, 5, 0, 3, 3] + db = merging_tables.Database(counts) + db.merge(5, 5) + self.assertEqual(10, db.max_row_count) + db.merge(4, 5) + self.assertEqual(10, db.max_row_count) + db.merge(3, 4) + self.assertEqual(10, db.max_row_count) + db.merge(2, 3) + self.assertEqual(11, db.max_row_count) + + +if __name__ == '__main__': + unittest.main() |