summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.idea/encodings.xml4
-rw-r--r--sources/merging_tables.py56
-rw-r--r--tests/merging_tablesTest.py34
3 files changed, 94 insertions, 0 deletions
diff --git a/.idea/encodings.xml b/.idea/encodings.xml
new file mode 100644
index 0000000..15a15b2
--- /dev/null
+++ b/.idea/encodings.xml
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+ <component name="Encoding" addBOMForNewFiles="with NO BOM" />
+</project> \ No newline at end of file
diff --git a/sources/merging_tables.py b/sources/merging_tables.py
new file mode 100644
index 0000000..985316e
--- /dev/null
+++ b/sources/merging_tables.py
@@ -0,0 +1,56 @@
+# python3
+
+
+class Database:
+ def __init__(self, row_counts):
+ self.row_counts = row_counts
+ self.max_row_count = max(row_counts)
+ n_tables = len(row_counts)
+ self.ranks = [1] * n_tables
+ self.parents = list(range(n_tables))
+ self.set_id = list(range(n_tables))
+
+ def find(self, i):
+ if i != self.parents[i]:
+ self.parents[i] = self.find(self.parents[i])
+ return self.parents[i]
+
+ def merge(self, src, dst):
+ real_dst = self.find(dst)
+ real_src = self.find(src)
+ if real_dst == real_src:
+ return
+ self.max_row_count = max(self.max_row_count,
+ self.row_counts[self.set_id[real_dst]] + self.row_counts[self.set_id[real_src]])
+ self.row_counts[real_dst] = self.row_counts[self.set_id[real_dst]] + self.row_counts[self.set_id[real_src]]
+ self.row_counts[real_src] = 0
+ self.set_id[real_src] = real_dst
+ self.set_id[real_dst] = real_dst
+ self.union(src, dst)
+
+ def union(self, src, dst):
+ destination_id = self.find(dst)
+ source_id = self.find(src)
+ if destination_id == source_id:
+ return
+ if self.ranks[destination_id] > self.ranks[source_id]:
+ self.parents[source_id] = destination_id
+ else:
+ self.parents[destination_id] = source_id
+ if self.ranks[destination_id] == self.ranks[source_id]:
+ self.ranks[source_id] = self.ranks[source_id] + 1
+
+
+def main():
+ n_tables, n_queries = map(int, input().split())
+ counts = list(map(int, input().split()))
+ assert len(counts) == n_tables
+ db = Database(counts)
+ for i in range(n_queries):
+ dst, src = map(int, input().split())
+ db.merge(dst - 1, src - 1)
+ print(db.max_row_count)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/merging_tablesTest.py b/tests/merging_tablesTest.py
new file mode 100644
index 0000000..ed83135
--- /dev/null
+++ b/tests/merging_tablesTest.py
@@ -0,0 +1,34 @@
+import unittest
+from sources import merging_tables
+
+
+class MyTestCase(unittest.TestCase):
+ def test1(self):
+ counts = [1, 1, 1, 1, 1]
+ db = merging_tables.Database(counts)
+ db.merge(4, 2)
+ self.assertEqual(2, db.max_row_count)
+ db.merge(3, 1)
+ self.assertEqual(2, db.max_row_count)
+ db.merge(3, 0)
+ self.assertEqual(3, db.max_row_count)
+ db.merge(3, 4)
+ self.assertEqual(5, db.max_row_count)
+ db.merge(2, 4)
+ self.assertEqual(5, db.max_row_count)
+
+ def test2(self):
+ counts = [10, 0, 5, 0, 3, 3]
+ db = merging_tables.Database(counts)
+ db.merge(5, 5)
+ self.assertEqual(10, db.max_row_count)
+ db.merge(4, 5)
+ self.assertEqual(10, db.max_row_count)
+ db.merge(3, 4)
+ self.assertEqual(10, db.max_row_count)
+ db.merge(2, 3)
+ self.assertEqual(11, db.max_row_count)
+
+
+if __name__ == '__main__':
+ unittest.main()