1717 set expression (집합의 표현) with Union-Find in Python3

by 빠니몽 2023. 4. 23.

0. Background

This post explains why this problem should be solved with union-find rather than with Python's Set.

 

1. Union-Find

1-1. What is it?

A data structure for managing data that is split into non-overlapping subsets. It is also known as a disjoint-set data structure.

 

1-2. How does it work?

As the name suggests, it maintains disjoint sets through two operations: union, which merges the sets containing two nodes, and find, which returns the root parent of the set a node belongs to.

 

1-2-1. Find

Find is implemented with a recursive function. It follows parent links until it reaches the root, and points each visited node directly at the root on the way back.

def find_parent(x):
    # if x is not the root, call itself recursively until the root is found
    if parent[x] != x:
        parent[x] = find_parent(parent[x])
    return parent[x]

1-2-2. Union

Union puts two different sets together, and the root with the lower number becomes the parent node of the united set.

def union(a, b):
    a = find_parent(a)
    b = find_parent(b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b
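
To see the two operations working together, here is a minimal self-contained sketch. The array size of 8 and the particular union calls are only illustrative assumptions, not part of the problem input.

```python
# Minimal union-find sketch; the size and the calls below are illustrative.
parent = list(range(8))  # parent[i] == i means i is the root of its own set

def find(x):
    # Follow parent links to the root, compressing the path on the way back.
    if parent[x] != x:
        parent[x] = find(parent[x])
    return parent[x]

def union(a, b):
    # Attach the root with the larger number under the smaller one.
    a, b = find(a), find(b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b

union(1, 3)
union(6, 7)
union(3, 7)

same_1_6 = find(1) == find(6)  # 1 and 6 were merged through 3 and 7
same_1_2 = find(1) == find(2)  # 2 was never merged with anything
print(same_1_6, same_1_2)
```

Two elements are in the same set exactly when find returns the same root for both, which is the check the problem asks for.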

 

2. Reason you can't use Set

This post started from a question: what if Python's Set were used for this problem? Would it work?

So I tried, and it ran out of memory. That puzzled me, since N is at most one million, and as far as I know a 128MB memory limit leaves room for roughly 10 million elements.

After some searching, I found the answer.

 

A Set takes up more than 3 times as much memory as a list holding the same elements.

>>> from sys import getsizeof
>>> a=[i for i in range(1000)]
>>> b={i for i in range(1000)}
>>> getsizeof(a)
        9024
>>> getsizeof(b)
        32992

A list only stores one reference per element, whereas Set and Dictionary are hash tables: they keep each element's hash value alongside the reference and leave part of the table empty to limit collisions, so they take up more space. This is why my code with Set ran out of memory.
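
The same comparison can be repeated at the problem's scale. The sketch below assumes CPython and uses n = 1,000,000 as the upper bound on N; the exact byte counts depend on the Python version, so none are stated here.

```python
from sys import getsizeof

n = 1_000_000  # assumed upper bound on N for this problem

as_list = list(range(n))
as_set = set(range(n))

# getsizeof reports only the container itself (its table of references),
# not the int objects it refers to.
list_bytes = getsizeof(as_list)
set_bytes = getsizeof(as_set)

print(f"list: {list_bytes / 2**20:.1f} MiB")
print(f"set:  {set_bytes / 2**20:.1f} MiB")
print(f"ratio: {set_bytes / list_bytes:.2f}")
```

Since both containers refer to the same kind of int objects, the gap between the two numbers comes from the containers themselves.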

3. Final Code

import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**7)

def find(parents, x):
    if parents[x] != x:
        parents[x] = find(parents, parents[x])
    return parents[x]

def union(parents, a, b):
    a = find(parents, a)
    b = find(parents, b)
    if a < b:
        parents[b] = a
    else:
        parents[a] = b

n, m = map(int, input().split())
parents = list(range(n + 1))  # each element starts as its own root

for _ in range(m):
    instruct, a, b = map(int, input().split())
    # a == b needs no lookup: an element is always in its own set
    if instruct == 1 and a == b:
        print("yes")
        continue
    elif instruct == 0 and a == b:
        continue

    if instruct == 0:
        union(parents, a, b)  # 0 a b: merge the two sets
    else:
        # 1 a b: check whether a and b share a root
        if find(parents, a) == find(parents, b):
            print("yes")
        else:
            print("no")
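
The final code relies on a raised recursion limit for the recursive find. As an alternative, find can be written with loops. This is a swapped-in variant rather than what the code above uses, and the name find_iterative and the test chain are hypothetical.

```python
def find_iterative(parents, x):
    # First pass: walk up to the root.
    root = x
    while parents[root] != root:
        root = parents[root]
    # Second pass: point every node on the path directly at the root.
    while parents[x] != root:
        parents[x], x = root, parents[x]
    return root

# A long path 100000 -> 99999 -> ... -> 1 -> 0, deeper than the default
# recursion limit, to show that no recursion is involved.
chain = [0] + list(range(100_000))  # chain[i] == i - 1 for i >= 1
root_of_last = find_iterative(chain, 100_000)
print(root_of_last, chain[100_000])
```

It takes the same arguments as find(parents, x) above, so it could be called in its place without changing union.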