포시코딩

[MST][최소 신장 트리] Prim's algorithm 본문

자료구조알고리즘/이론

[MST][최소 신장 트리] Prim's algorithm

포시 2023. 5. 10. 21:22
728x90

프림 알고리즘(Prim's algorithm)

  • Kruskal's algorithm과 함께 대표적인 최소 신장 트리(Minimum Spanning Tree) 알고리즘 중 하나
  • Kruskal's algorithm보다 다소 복잡도가 올라간 알고리즘이다. 
  • 시작 정점을 선택한 후, 정점에 인접한 간선 중 최소 간선으로 연결된 정점을 선택하고, 
    해당 정점에서 다시 최소 간선으로 연결된 정점을 선택하는 방식으로 최소 신장 트리를 확장해가는 방식

 

Kruskal's algorithm과 Prim's algorithm 비교

  • 둘 다 탐욕 알고리즘을 기초로 하고 있음(당장 눈 앞의 최소 비용을 선택해서, 결과적으로 최적의 솔루션을 찾음)
  • Kruskal's algorithm은 가장 가중치가 작은 간선부터 선택하면서 MST를 구한다면
  • Prim's algorithm은 특정 정점(노드)에서 시작, 해당 정점에 연결된 가장 가중치가 작은 간선을 선택, 
    간선으로 연결된 정점들에 연결된 간선 중에서 가장 가중치가 작은 간선을 택하는 방식으로 MST를 구한다.
  • Kruskal's algorithm
    전체 정보를 다 알고 있다는 가정하에 모든 간선 리스트를 한꺼번에 정렬 후 가장 낮은 가중치를 가진 간선을 선택하는 방식
  • Prim's algorithm
    선택된 노드를 기준으로 점차 정보를 추가해가며 한정된 정보하에 가장 낮은 가중치를 가진 간선을 선택하는 방식

 

사용방법

  1. 임의의 정점을 선택, '연결된 노드 집합'에 삽입
  2. 선택된 정점에 연결된 간선들을 '간선 리스트'에 삽입
  3. 간선 리스트에서 최소 가중치를 가지는 간선부터 추출해서, 
    • 해당 간선에 연결된 인접 정점이 '연결된 노드 집합'에 이미 들어 있다면, 스킵(cycle 방지)
    • 해당 간선에 연결된 인접 정점이 '연결된 노드 집합'에 들어 있지 않으면, 해당 간선을 선택하고, 
      해당 간선 정보를 '최소 신장 트리'에 삽입
  4. 추출한 간선은 간선 리스트에서 제거
  5. 간선 리스트에 더 이상의 간선이 없을 때까지 3~4 반복

 

사용 라이브러리

heapq를 통한 운선순위 큐 사용

import heapq

queue = list()
graph_data = [[5, 'B'], [2, 'A'], [3, 'C']]

# for edge in graph_data:
#   heapq.heappush(queue, edge)

# for index in range(len(queue)):
#   print(heapq.heappop(queue))

# while queue:
#   print(heapq.heappop(queue))

heapq.heapify(graph_data)

while graph_data:
  print(heapq.heappop(graph_data))

 

 

collections 라이브러리의 defaultdict 함수 활용

  • defaultdict 함수를 사용해서, key에 대한 value를 지정하지 않았을 시, 빈 리스트로 초기화
  • 미리 초기화하지 않아도 됨!
from collections import defaultdict

list_dict = defaultdict(list)  
# list -> 초기화된 값이 없을 경우 리턴할 초기 값에 대한 데이터 타입
print(list_dict['key1'])  # []

list_dict2 = defaultdict(int)
print(list_dict2['A'])  # 0

 

코드 구현

그래프 코드

  • 이미 작성된 간선 정보(중복)는 피해서 작성
myedges = [
  (7, 'A', 'B'), (5, 'A', 'D'),                 # A에 연결된 간선
  (8, 'B', 'C'), (9, 'B', 'D'), (7, 'B', 'E'),  # B에 연결된 간선
  (5, 'C', 'E'),                                # C에 연결된 간선
  (7, 'D', 'E'), (6, 'D', 'F'),                 # D에 연결된 간선
  (8, 'E', 'F'), (9, 'E', 'G'),                 # E에 연결된 간선
  (11, 'F', 'G')                                # F에 연결된 간선
]

알고리즘 구현 코드

from collections import defaultdict
from heapq import *

def prim(start_node, edges):
  mst = list()

  adjacent_edges = defaultdict(list)
  for weight, n1, n2 in edges:
    adjacent_edges[n1].append((weight, n1, n2))
    adjacent_edges[n2].append((weight, n2, n1))
  
  connected_nodes = set(start_node)  # 연결된 노드 집합
  candidate_edge_list = adjacent_edges[start_node]  # 선택된 노드에 대한 간선 리스트
  # candidate_edge_list(간선 리스트) 중 가중치가 제일 작은 걸 pop하기 위해 heap구조로 변환
  heapify(candidate_edge_list)

  while candidate_edge_list:
    weight, n1, n2 = heappop(candidate_edge_list)
    if n2 not in connected_nodes:  # 없을 경우 사이클 x
      connected_nodes.add(n2)
      mst.append((weight, n1, n2))  # 최소 신장 트리의 간선으로 선택

      # 후보군 추가
      for edge in adjacent_edges[n2]:  # edge -> (weight, n1, n2)
        if edge[2] not in connected_nodes:  # 이미 연결된 노드에 연결되는 간선은 고려 x
          heappush(candidate_edge_list, edge)

  return mst

print(prim('A', myedges))
# [(5, 'A', 'D'), (6, 'D', 'F'), (7, 'A', 'B'), (7, 'B', 'E'), (5, 'E', 'C'), (9, 'E', 'G')]

 

시간복잡도

  • adjacent_edges 초기화에 간선 수 만큼 반복 -> O(E)
  • 최악의 경우, while문에서 모든 간선에 대해 반복 -> O(E)
  • 최소 힙 구조 사용 -> O(logE)
  • 결과적으로 O(ElogE)의 시간복잡도를 가짐

 

+ 알고리즘 개선

  • 기존의 간선 중심 알고리즘이 아닌 노드를 중심으로 우선순위 큐를 적용하는 방식
  • 노드의 수가 간선의 수보다 작다는 점(최대 V^2 = E가 될 수 있음)을 이용해서 시간복잡도 개선
    • 각각의 노드에 초기값을 우선순위 큐에 세팅(특정 노드의 값: 0, 나머지 노드: 무한대)
    • 가장 값이 작은 노드를 뽑아(pop) 연결된 간선의 가중치가 연결된 노드의 값보다 작으면
      해당 노드의 값을 가중치 값으로 업데이트
    • 위 과정 반복
      * 2번의 업데이트 시 heap 내부 값도 빼내어 업데이트 후 다시 넣어줘야 하는데 
      파이썬에서 굳이 그런 과정 없이도 내부 값을 업데이트 하는 방법을 제공한다. -> heapdict

 

코드

from heapdict import heapdict
# 설치가 필요하다면 -> pip install HeapDict

def prim(graph, start):
  mst, keys, pi, total_weight = list(), heapdict(), dict(), 0
  for node in graph.keys():
    keys[node] = float('inf')
    pi[node] = None
  keys[start], pi[start] = 0, start

  while keys:
    current_node, current_key = keys.popitem()
    mst.append([pi[current_node], current_node, current_key])
    total_weight += current_key
    for adjacent, weight in graph[current_node].items():  # key 값을 업데이트 시키는 부분
      if adjacent in keys and weight < keys[adjacent]:
        keys[adjacent] = weight  # key값 업데이트 시 자동으로 heap 구조 변경
        pi[adjacent] = current_node
  return mst, total_weight

mygraph = {
  'A': {'B': 7, 'D': 5},
  'B': {'A': 7, 'D': 9, 'C': 8, 'E': 7},
  'C': {'B': 8, 'E': 5},
  'D': {'A': 5, 'B': 9, 'E': 7, 'F': 6},
  'E': {'B': 7, 'C': 5, 'D': 7, 'F': 8, 'G': 9},
  'F': {'D': 6, 'E': 8, 'G': 11},
  'G': {'E': 9, 'F': 11}    
}
mst, total_weight = prim(mygraph, 'A')
print ('MST:', mst)
print ('Total Weight:', total_weight)
728x90