์ต์ ์ ์ฅ ํธ๋ฆฌ (MST)
MST Algorithms ์ ์ข
๋ฅ
Prim Algorithm - ์ ์
์ ์ ํ
KRUSKAL Algorithm - ๊ฐ์
์ ์ ํ๋จ
๊ทธ๋ํ์์ ์ต์ ๋น์ฉ ๋ฌธ์
๋ชจ๋ ์ ์ ์ ์ฐ๊ฒฐํ๋ ๊ฐ์ ๋ค์ ๊ฐ์ค์น์ ํฉ์ด ์ต์๊ฐ ๋๋ Tree
๋ ์ ์ ์ฌ์ด์ ์ต์ ๋น์ฉ์ ๊ฒฝ๋ก ์ฐพ๊ธฐ
์ ์ฅ ํธ๋ฆฌ (Spanning Tree)
: n๊ฐ์ ์ ์ ์ผ๋ก ์ด๋ฃจ์ด์ง ๋ฌด๋ฐฉํฅ ๊ทธ๋ํ์์ n๊ฐ์ ์ ์ ๊ณผ n-1๊ฐ์ ๊ฐ์ ์ผ๋ก ์ด๋ฃจ์ด์ง ํธ๋ฆฌ
์ต์ ์ ์ฅ ํธ๋ฆฌ (Mininum Spanning Tree)
: ๋ฌด๋ฐฉํฅ ๊ฐ์ค์น ๊ทธ๋ํ์์ ์ ์ฅํธ๋ฆฌ๋ฅผ ๊ตฌ์ฑํ๋ ๊ฐ์ ๋ค์ ๊ฐ์ค์น์ ํฉ์ด ์ต์์ธ ์ ์ฅํธ๋ฆฌ
MST ํํ
Graph ํํ
๋ถ๋ชจ ์์ ๊ด๊ณ์ ๊ฐ์ค์น์ ๋ํ ๋ฐฐ์ด
Tree
๋ก ๋ํ๋ผ ์ ์๋ค
Prim Algorithm
ํ๋์ ์ ์
์์ ์ฐ๊ฒฐ๋ ๊ฐ์ ๋ค ์ค์ ํ๋์ฉ ์ ํํ๋ฉด์ MST๋ฅผ ๋ง๋ค์ด ๊ฐ๋ ๋ฐฉ์
์์์ ์ ์ ์ ํ๋ ์ ํํด์ ์์
์ ํํ ์ ์ ๊ณผ ์ธ์ ํ๋ ์ ์ ๋ค ์ค์ ์ต์ ๋น์ฉ์ ๊ฐ์ ์ด ์กด์ฌํ๋ ์ ์ ์ ์ ํ
๋ชจ๋ ์ ์ ์ด ์ ํ๋ ๋ ๊น์ง 1,2 ๊ณผ์ ๋ฐ๋ณต
์๋ก์(์ํธ ๋ฐฐํ)์ธ 2๊ฐ์ ์งํฉ(2 disjoint-sets) ์ ๋ณด๋ฅผ ์ ์ง
ํธ๋ฆฌ ์ ์ ๋ค (tree vertices)
MST๋ฅผ ๋ง๋ค๊ธฐ ์ํด ์ ํ๋ ์ ์ ๋ค
๋นํธ๋ฆฌ ์ ์ ๋ค (nontree vertices)
์ ํ ๋์ง ์์ ์ ์ ๋ค
ex)
MST
"""
MST + ์ธ์ ํ๋ ฌ
7 11
0 5 60
0 1 32
0 2 31
0 6 51
1 2 21
2 4 46
2 6 25
3 4 34
3 5 18
4 5 40
4 6 51
์์์ ์ ๋์ ์ ๊ฐ์ค์น
๊ฒฐ๊ณผ
[0, 21, 31, 34, 46, 18, 25]
175
"""
V, E = map(int, input().split())
adj = [[0] * V for _ in range(V)]
for i in range(E):
s, e, c = map(int, input().split())
adj[s][e] = c
adj[e][s] = c # ๋ฌด๋ฐฉํฅ ๊ทธ๋ํ๋ผ์
# for row in adj:
# print(row)
# key, p(parent), mst ์ค๋น
INF = float('inf')
key = [INF] *V # key๋ ๋ฌดํ๋๋ก ์ด๊ธฐํ
p = [-1] * V # p(parent)๋ -1๋ก ์ด๊ธฐํ
mst = [False] * V
# ์์์ ์ ํ : 0๋ฒ ์ ํ
key[0] = 0
cnt = 0
result = 0
while cnt < V:
# ์์ง mst๊ฐ ์๋๊ณ , key๊ฐ ์ต์์ธ ์ ์ ์ ํ : u
MIN = INF
u = -1
for i in range(V):
if not mst[i] and key[i] <MIN:
MIN = key[i]
u = i
# u๋ฅผ mst๋ก ์ ํ
mst[u] = True
result += MIN
cnt+=1
# key ๊ฐ์ ๊ฐฑ์
# u์ ์ธ์ ํ๊ณ , ์์ง mst๊ฐ ์๋ ์ ์ w์์ key[w] > u - w ๊ฐ์ค์น ์ด๋ฉด ๊ฐฑ์ !
for w in range(V):
if adj[u][w] > 0 and not mst[w] and key[w] > adj[u][w]:
key[w] = adj[u][w]
p[w] = u
print(key) # [0, 21, 31, 34, 46, 18, 25]
print(p) # [-1, 2, 0, 4, 2, 3, 2]
print(result) # 175
ex)
Mst - prim algorithm
# Prim Algorithm
# : ํ๋์ ์ ์ ์์ ์ฐ๊ฒฐ๋ ๊ฐ์ ๋ค ์ค์ ํ๋์ฉ ์ ํ ํ๋ฉด์ MST๋ฅผ ๋ง๋ค์ด ๊ฐ๋ ๋ฐฉ์
# ์ฐ์ ์์ ํ ํ์ฉํ๊ธฐ -> ์ด์งํ -> heapq
import heapq
"""
MST + ์ธ์ ํ๋ ฌ
7 11
0 5 60
0 1 32
0 2 31
0 6 51
1 2 21
2 4 46
2 6 25
3 4 34
3 5 18
4 5 40
4 6 51
์์์ ์ ๋์ ์ ๊ฐ์ค์น
๊ฒฐ๊ณผ
[0, 21, 31, 34, 46, 18, 25]
175
"""
V, E = map(int, input().split())
adj = {i: [] for i in range(V)}
for i in range(E):
s, e, c = map(int, input().split())
adj[s].append([e,c])
adj[e].append([s,c]) #๋ฌด๋ฐฉํฅ์ด๋ผ์
# print(adj)
# key, mst, ์ฐ์ ์์ ํ ์ค๋น
INF = float('inf')
key = [INF] * V
mst = [False] * V
pq = []
# ์์ ์ ์ ์ ํ : 0
key[0] = 0
# ํ์ ์์ ์ ์ ์ ๋ฃ์ => (key, ์ ์ index) ๋ฌถ์ด์ ๋ฃ๊ธฐ
# ์ฐ์ ์์ ํ => ์ด์งํ => heapq ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์ฌ์ฉ
heapq.heappush(pq, (0,0)) # heqppush(๋ฐฐ์ด ์ ๋ณด, ์ด๋ค ์์ ๋ฃ์ ์ง) : heap์ ๊ตฌ์กฐ๋ฅผ ์ ์งํ๋ฉด์ ํ๋์ ์์๋ฅผ ๋ฃ์ด์ค
# ์ฐ์ ์์ํ -> ์์์ ์ฒซ๋ฒ์งธ ์์ -> key๋ฅผ ์ฐ์ ์์๋ก
result = 0
while pq:
# ์ต์๊ฐ ์ฐพ๊ธฐ
k, node = heapq.heappop(pq) #๋ด๊ฐ ๊ฐ๊ณ ์๋ heapq์์ ์ต์๊ฐ์ popํด์ค
if mst[node]: continue # ๊ฐฑ์ ํ ๋ ํ์ํ๋ ์ ๋ณด๋ skipํ๊ธฐ
# mst๋ก ์ ํ
mst[node] = True
result += k
# key๊ฐ ๊ฐฑ์ => key๋ฐฐ์ด / ํ
for destination, weight in adj[node]:
if not mst[destination] and key[destination] > weight:
key[destination] = weight
# ํ ๊ฐฑ์ => ์๋ก์ด (key, ์ ์ ) ์ฝ์
=> ํ์์๋ ์์๋ skip
heapq.heappush(pq, (key[destination], destination))
print(result) # 175
KRUSKAL Algorithm
: ๊ฐ์ ์ ํ๋์ฉ ์ ํํด์ MST๋ฅผ ์ฐพ๋ ์๊ณ ๋ฆฌ์ฆ
์ต์ด, ๋ชจ๋ ๊ฐ์ ์ ๊ฐ์ค์น
์ ๋ฐ๋ผ ์ค๋ฆ์ฐจ์์ผ๋ก ์ ๋ ฌ
๊ฐ์ค์น๊ฐ ๊ฐ์ฅ ๋ฎ์ ๊ฐ์ ๋ถํฐ ์ ํํ๋ฉด์ ํธ๋ฆฌ๋ฅผ ์ฆ๊ฐ์ํด
์ฌ์ดํด์ด ์กด์ฌํ๋ฉด ๋ค์์ผ๋ก ๊ฐ์ค์น๊ฐ ๋ฎ์ ๊ฐ์ ์ ํ
์ฌ์ดํด์ ์ ํํ์ง ์๋๋ค!
์ฌ์ดํด์ด ์๋ค๋ ๊ฒ์ ์ต๋จ๊ฑฐ๋ฆฌ๋ก ๊ฐ๋ ๊ฒ์ด ์๋๋ผ ๋์์ ๊ฐ๋ ๊ฒ์ ์๋ฏธํ๊ธฐ ๋๋ฌธ
์ฌ์ดํด์ด ์กด์ฌํ๋์ง ํ์ธํ๋ ๋ฒ
์ ์ ์ ๋ํ์๊ฐ ๊ฐ์ผ๋ฉด ์ฌ์ดํด์ด ์๋ค๋ ๊ฒ!
n-1
๊ฐ์ ๊ฐ์ ์ด ์ ํ๋ ๋ ๊น์ง 2๋ฅผ ๋ฐ๋ณต
ex)
kruskal
"""
์์
7 11
0 5 60
0 1 32
0 2 31
0 6 51
1 2 21
2 4 46
2 6 25
3 4 34
3 5 18
4 5 40
4 6 51
"""
def make_set(x):
p[x] = x
def find_set(x):
if p[x] == x:
return x
else:
p[x] = find_set(p[x])
return p[x]
def union(x,y):
px = find_set(x)
py = find_set(y)
if rank[px] > rank[py]:
p[py] = px
else:
p[px] = py
if rank[px] == rank[py]:
rank[py] += 1
V, E = map(int, input().split())
edges = [ list(map(int, input().split()))for _ in range(E)]
# ๊ฐ์ ์ ๊ฐ์ ๊ฐ์ค์น๋ฅผ ๊ธฐ์ค์ผ๋ก ์ ๋ ฌ
edges.sort(key=lambda x: x[2]) # () ์์ ๊ธฐ์ค์ด ๋ค์ด๊ฐ
# make_set : ๋ชจ๋ ์ ์ ์ ๋ํด ์งํฉ ์์ฑ
p = [0] * V
rank = [0] * V
for i in range(V):
make_set(i)
cnt = 0
result = 0
mst = []
# ๋ชจ๋ ๊ฐ์ ์ ๋ํด์ ๋ฐ๋ณต -> V-1๊ฐ์ ๊ฐ์ ์ด ์ ํ๋ ๋ ๊น์ง
for i in range(E):
s, e, c = edges[i][0], edges[i][1], edges[i][2]
# ์ฌ์ดํด์ด๋ฉด skip : ์ฑํํ๋ ค๋ ๋ ์ ์ ์ด ใ
๋ก ๊ฐ์ ์งํฉ์ด๋ฉด skip => find_set ์ด์ฉ
if find_set(s) == find_set(e):
continue
# ๊ฐ์ ์ ํ
result += c
mst.append(edges[i])
# => mst์ ๊ฐ์ ์ ๋ณด ๋ํ๊ธฐ / ๋ ์ ์ ์ ํฉ์น๋ค => Union
union(s,e)
cnt +=1
if cnt == V-1:
break
print(result) # 175
print(mst) #[[3, 5, 18], [1, 2, 21], [2, 6, 25], [0, 2, 31], [3, 4, 34], [2, 4, 46]]