1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
|
# python3
class Database:
    """Union-find over tables that tracks the total row count of each merged set.

    Sets are linked by rank and ``find`` compresses paths.

    Attributes:
        row_counts: per-table row counts, a private copy of the constructor
            argument. After any sequence of ``merge`` calls,
            ``row_counts[find(i)]`` is the total row count of the set holding
            table ``i``; slots of tables that stopped being roots hold 0.
        max_row_count: largest total row count of any set seen so far.
        ranks: union-by-rank bound per table; only meaningful at roots.
        parents: parent pointer per table; a root is its own parent.
        set_id: kept for backward compatibility. After every ``merge``,
            ``set_id[find(i)] == find(i)``, so the older lookup
            ``row_counts[set_id[find(i)]]`` still yields the set's total.
    """

    def __init__(self, row_counts):
        """Create one singleton set per table.

        Args:
            row_counts: iterable of row counts, one per table. It is copied,
                so merges never mutate the caller's data.
        """
        self.row_counts = list(row_counts)
        # With no tables there are no rows; plain max() would raise ValueError.
        self.max_row_count = max(self.row_counts, default=0)
        n_tables = len(self.row_counts)
        self.ranks = [1] * n_tables
        self.parents = list(range(n_tables))
        self.set_id = list(range(n_tables))

    def find(self, i):
        """Return the root of the set containing table ``i``.

        Every table visited on the way is re-pointed directly at the root
        (path compression), so later lookups are shorter.
        """
        parents = self.parents
        root = i
        while parents[root] != root:
            root = parents[root]
        # Second pass: compress the path walked above.
        while i != root:
            next_i = parents[i]
            parents[i] = root
            i = next_i
        return root

    def merge(self, src, dst):
        """Merge the sets containing tables ``src`` and ``dst``.

        The combined total is stored at the new root, and ``max_row_count``
        is raised if the new set is the largest so far. Merging two tables
        that are already in the same set is a no-op.
        """
        src_root = self.find(src)
        dst_root = self.find(dst)
        if src_root == dst_root:
            return
        total = self.row_counts[src_root] + self.row_counts[dst_root]
        # Both arguments are roots, so union's own find() calls return at once.
        new_root = self.union(src_root, dst_root)
        old_root = dst_root if new_root == src_root else src_root
        # Keep the total at the root itself so row_counts[find(i)] is correct
        # no matter which side union() picked as the new root.
        self.row_counts[new_root] = total
        self.row_counts[old_root] = 0
        self.set_id[new_root] = new_root
        self.set_id[old_root] = new_root
        self.max_row_count = max(self.max_row_count, total)

    def union(self, src, dst):
        """Link the sets containing ``src`` and ``dst`` by rank.

        Only the tree structure changes; row counts are maintained by
        ``merge``, not here. On equal ranks the source root becomes the new
        root and its rank grows by one.

        Returns:
            The root of the combined set.
        """
        dst_root = self.find(dst)
        src_root = self.find(src)
        if dst_root == src_root:
            return dst_root
        if self.ranks[dst_root] > self.ranks[src_root]:
            self.parents[src_root] = dst_root
            return dst_root
        self.parents[dst_root] = src_root
        if self.ranks[dst_root] == self.ranks[src_root]:
            self.ranks[src_root] += 1
        return src_root
def main():
    """Read tables and merge queries from stdin; print the largest set after each merge.

    Input, as parsed below: a line "n_tables n_queries", a line with
    n_tables row counts, then n_queries lines "destination source" holding
    1-based table numbers. One line is written per query.

    Raises:
        ValueError: if the second line does not hold exactly n_tables counts,
            or a line cannot be parsed as integers.
    """
    import sys

    readline = sys.stdin.readline
    n_tables, n_queries = map(int, readline().split())
    counts = list(map(int, readline().split()))
    # Explicit check instead of assert, which is stripped under `python -O`.
    if len(counts) != n_tables:
        raise ValueError(f"expected {n_tables} row counts, got {len(counts)}")
    db = Database(counts)
    results = []
    for _ in range(n_queries):
        destination, source = map(int, readline().split())
        # Keywords keep each table on the matching merge() parameter.
        db.merge(src=source - 1, dst=destination - 1)
        results.append(db.max_row_count)
    # One batched write instead of a print() per query.
    sys.stdout.write("".join(f"{r}\n" for r in results))
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script, not on import.
    main()
|