Jungle

[TIL][Krafton Jungle] 2주차 - (4) : 위상정렬 알고리즘

손가든 2023. 10. 22. 23:32

오늘은 DFS를 마치고 위상정렬에 대해 공부하기 시작했다.
 


위상 정렬

위상 정렬이란?

방향이 있는 그래프에서 방향을 거스르지 않고 모든 정점을 잇는 경로를 선형순서로 나열하는 것
이때 그래프는 사이클이 없는 방향 그래프(DAG)여야 한다.
 
이때 든 생각은,
알고리즘 문제에서 일단 사이클이 있는지를 판단하고
사이클이 있다면 위상 정렬을 하도록 하는 문제가 있을 거라 생각했다.
사이클이 있는지 판단하는 함수는 Kruskal에서의 노드의 부모를 고정하여 판단하는 방식이 사용될 것 같다.
 

[정방향] 위상 정렬의 동작 예시(경로를 나열하기 위한 진행 순서)는 다음과 같다.

1. 진입차수가 0인 모든 노드를 큐에 삽입한다.
2. 큐에서 POP을 하고 POP한 노드에서 나가는 간선을 제거한다.
3. 제거한 후 진입차선이 0이 된 노드를 다시 큐에 삽입한다.
  -> 이 3가지를 반복
 

근데 이 정방향 위상 정렬은 이해하기 쉽지만 각 정점의 진입 차수가 0인지 확인하기 위해서는 모든 각 연결 정보들을 파악해서 연결되지 않은 정점을 추려내야 하는 단점이 있다. (시간복잡도 측면에서 오래걸린다)
 
따라서 역방향으로 하는 알고리즘이 매우 간편하다는 걸 알게 되었다.
 
[역방향] 위상 정렬의 동작 과정

1. 0번부터 방문하지 않은 노드들을 for문을 돌려 일반 dfs에 넣는다.
2. dfs에 들어온 노드의 진행방향의 노드들을 for문으로 dfs 재귀를 한다.
3. 2번의 for문을 다 끝낸 밑에 결과값 순서 리스트에 노드를 삽입한다.
4. 결과값 순서 리스트를 역순으로 출력한다.
 
 
말로 설명되어 매우 복잡하고 이해하기 어렵지만 코드로 분석하면 매우 간단하다는 걸 알 수 있다. 
다음의 코드를 한번 살펴보자
 

adj_list = [[1],[3,4],[0,1],[6],[5],[7],[8],[]]

N = len(adj_list)
visited = [False] * N
s = []

def dfs(v):
	visited[v] = True
    for w in adj_list[v]:
    	if not visited[w]:
        	dfs(w)
    
	s.append(v)


for i in range(N):
	if not visited[i]:
    	dfs(i)

s.reverse()
print('위상정렬:')
print(s)

####### print 결과물 #######
###[2,0,1,4,5,3,6,7,8]

 

맨 밑에 for문을 통해 모든 노드를 탐색한다 ( visited = True 라면 패스한다)
0번 노드부터 그래프의 방향을 거스르지 않도록 탐색을 시작해서 마지막 끝나는 지점부터 거꾸로 결과에 삽입한다.
만약 0번 노드에 없는 노드(이 그림에선 2번)라면 s에 삽입되지 않고 맨 처음 진행했던 for문을 통해 나중에 삽입된다.
 
이 삽입된 리스트를 거꾸로 뒤집으면 순위가 제일 낮은(즉 제일 처음 시작되는 노드) 노드부터 출력된다.
 
위상 정렬 공부를 끝낸 뒤 백준 위상정렬 문제를 풀어보았다.
 

BaekJun - #1948 : 임계경로 (위상정렬)

첫 플레티넘인 만큼 엄청나게 악명있었다.
 
임계경로 문제는 위상정렬을 반드시 사용하여 진입정점을 카운트하며 정점을 순회해야하는 문제였다.
왜냐하면 1번 정점에서 2번 정점으로 가는 길이 한가지가 아닐 수 있기 때문이다.
 
그럼 왜 1번에서 2번으로 가는길이 여러가지일 때 진입정점을 카운트해야하는가?
이유는 다음과 같다.
 
아래 코드를 보며 이해해보자

from collections import deque
import sys
import heapq
sys.setrecursionlimit(10**6)

n = int(sys.stdin.readline().strip())
m = int(sys.stdin.readline().strip())
world = {}
back_world = {}
for _ in range(m):
    v1, v2, time = map(int,sys.stdin.readline().strip().split())
    if not v1 in world:
        world[v1] = [[v2,time]]
    else :
        world[v1].append([v2,time])
    if not v2 in back_world:
        back_world[v2] = [(v1,time)]
    else:
        back_world[v2].append((v1,time))

start , end = map(int,sys.stdin.readline().strip().split())

memorize = [0 for _ in range(n+1)]
visited = [False for _ in range(n+1)]
def dfs_time(v,cost):
    if v in world:
        for next in world[v]:
            next_city , time = next
            if memorize[next_city] < time+cost:
                memorize[next_city] = time+cost
                dfs_time(next_city,cost+time)

queue = deque()
queue.append(end)

def back_bfs():
    count = 0
    while queue:
        now = queue.popleft()
        if not now in back_world:
            continue
        visited[now] = True
        for pre,cost in back_world[now]:
                if memorize[now] == memorize[pre]+cost:
                    count +=1
                    if not visited[pre]:
                        queue.append(pre)
                        visited[pre] = True
    print(count)
                
if n ==1 and m==1:
    print(world[start][1])
    print(1)
else:
    dfs_time(start,0)
    print(memorize[end])
    back_bfs()

조금 지저분하지만 dfs_time() 함수는 모든 경로를 체크해서
시작 부분에서 각 정점으로 갔을 때의 최대 시간을 memorize 리스트에 기록해주는 함수이다.
 
예를 들어 memorize[3] 은 1에서 시작해서 어떤 경로든 3으로 도착할 때 제일 오래 걸렸던 경로의 시간이 저장된다.
 
지금 dfs를 살펴보면 시작정점에서 2번정점으로 방문하고 더 깊게 재귀하고 있다.
하지만 1번 정점에서 2번정점을 방문하는 1번 정점에서 2번 정점까지의 최대 시간이다.
 
하지만 만약 1번 정점에서 2번정점 사이에서도 엄청나게 많은 길이 있다면
1번에서 2번으로 가는 경우만 dfs 재귀를 1번과 2번사이의 길만큼 해줘야 한다는 점이다.
 
그래서 1번과 2번사이의 수많은 길은 물론 모두 확인해야 하지만 그 갯수만큼 재귀하는 것이 아니라 다 체크 한 뒤
1번만 재귀하러 가야된다는 것이다.
 
그래서 나는 BFS로 다시 바꾸었고,
진입차수를 처음에 받아놓은뒤 진입차수를 줄여나가며 최대값을 갱신하고
진입차수가가 0이 될 때 append를 수행했다.
 

from collections import deque
import sys
import heapq
sys.setrecursionlimit(10**6)

n = int(sys.stdin.readline().strip())
m = int(sys.stdin.readline().strip())
world = {}
back_world = {}
indegree = [0 for _ in range(n+1)]
for _ in range(m):
    v1, v2, time = map(int,sys.stdin.readline().strip().split())
    if not v1 in world:
        world[v1] = [(v2,time)]
    else :
        world[v1].append((v2,time))
    if not v2 in back_world:
        back_world[v2] = [(v1,time)]
    else:
        back_world[v2].append((v1,time))
    indegree[v2] += 1

start , end = map(int,sys.stdin.readline().strip().split())

memorize = [0 for _ in range(n+1)]
isCounted = []
visited = [False for _ in range(n+1)]
queue = deque()
def bfs():
    queue.append(start)
    while queue:
        now = queue.popleft()
        if now in world:
            for next,cost in world[now]:
                memorize[next] = max(memorize[next],cost+memorize[now])
                indegree[next] -=1
                if indegree[next] == 0:
                    queue.append(next)




def back_bfs():
    queue.append(end)
    count = 0
    while queue:
        now = queue.popleft()
        if not now in back_world:
            continue
        visited[now] = True
        for pre,cost in back_world[now]:
                if memorize[now] == memorize[pre]+cost:
                    count +=1
                    if not visited[pre]:
                        queue.append(pre)
                        visited[pre] = True
    print(count)

그렇게 하여 시간 세이브를 하고 시간 초과를 해결하여 문제를 해결했다.

 

 


++++++++++++ 10/23일 제대로 이해한 임계경로 시간 복잡도 단축 정리

 왜 1번 코드가 2번 코드보다 늦고, 왜 1번은 시간이 초과되는데 2번 코드는 돌아가는지 완전히 이해했다.

1번 코드와 2번코드의 차이점은 진입차수가 0이 된 다음 노드만 queue에 삽입한다는 점이다.

 

이해를 돕기 위해 사진으로 설명하겠음

다음과 같은 그래프가 있고 0에서 1로 가야하는 최장거리를 찾는다고 하자

 

만약에 진입차수를 고려하지 않고 queue에 삽입을 하면 3번의 비교 수행을 해야하는 걸

진입차수가 없어지는 3만 수행하는걸로 단축된다.

 

왜 이럴까?

 

진입차수가 있는 정점들은 절대 최단 거리가 될 수 없기 때문이다.

 

3으로 돌아서 2로 가는게 1에서 바로 2로 가는것보다 훨씬 오래걸리는 경로이기때문에

1에서 2로 바로 가는 소요시간을 비교할 필요도 없다.

 

이래서 진입차수의 판단이 시간 복잡도에서 엄청 중요하구나~

이래서 위상정렬 쓰는구나~