Viterbi算法
(部分內容轉自知乎:《如何通俗地講解 viterbi 算法?》)
1、問題描述
如下如所示,如何快速找到從 S 到 E 的最短路徑?
一:遍歷窮舉法,可行,但速度太慢;
二:viterbi算法!
注:viterbi 維特比算法解決的是籬笆型圖的最短路徑問題,圖的節點按列組織,每列的節點數量可以不一樣,每一列的節點只能和相鄰列的節點相連,不能跨列相連,節點之間有着不同的距離,距離的值就不在圖上一一標註出來了,大家自行腦補。
2、算法分析
(1)S 到 A 列的最短路徑
首先起點是S,從S到A列的路徑有三種可能:S-A1、S-A2、S-A3,如下圖:
我們不能武斷地說S-A1、S-A2、S-A3中的哪一段必定是全局最短路徑中的一部分,目前爲止任何一段都有可能是全局最短路徑的備選項。繼續往右看,到了B列,按B列的B1、B2、B3逐個分析。
(2)S 到 B 列的最短路徑
先看 B1,經過B1的所有路徑只有3條:S-A1-B1,S-A2-B1,S-A3-B1。
這三條路徑,各節點距離加起來對比一下,就可以知道其中哪一條是最短的。假設S-A3-B1是最短的,那麼我們就知道了經過B1的所有路徑當中S-A3-B1是最短的,其它兩條路徑路徑S-A1-B1和S-A2-B1都比S-A3-B1長,絕對不是目標答案,可以大膽地刪掉了。刪掉了不可能是答案的路徑,就是viterbi算法(維特比算法)的重點,因爲後面我們再也不用考慮這些被刪掉的路徑了。現在經過B1的所有路徑只剩一條路徑了,如下圖:
接下來我們繼續看B2,同理,經過B2的路徑有3條:S-A1-B2,S-A2-B2,S-A3-B2。
這三條路徑中,各節點距離加起來對比一下,肯定也可以知道其中哪一條是最短的,其它兩條路徑路徑S-A2-B2和S-A3-B1也可以刪掉了。經過B2所有路徑只剩一條,如下圖:
接下來我們繼續看B3,同理,經過B3的路徑也有3條:S-A1-B3,S-A2-B3,S-A3-B3。
這三條路徑中我們也肯定可以算出其中哪一條是最短的,假設S-A2-B3是最短的,那麼我們就知道了經過B3的所有路徑當中S-A2-B3是最短的,其它兩條路徑路徑S-A1-B3和S-A3-B3也可以刪掉了。經過B3的所有路徑只剩一條,如下圖:
現在對於B列的所有節點我們都過了一遍,B列的每個節點我們都刪除了一些不可能是答案的路徑,刪掉這些不可能是最短路徑的情況之後,留下了三個有可能是最短的路徑:S-A3-B1、S-A1-B2、S-A2-B3。現在我們將這三條備選的路徑放在一起彙總到下圖:
(3)S 到 C 列的最短路徑
類似上面說的B列,我們從C1、C2、C3一個個節點分析。
經過C1節點的路徑有:S-A3-B1-C1、S-A1-B2-C1、S-A2-B3-C1。
和B列的做法一樣,從這三條路徑中找到最短的那條(假定是S-A3-B1-C1),其它兩條路徑同樣道理可以刪掉了。那麼經過C1的所有路徑只剩一條,如下圖:
同理,我們可以找到經過C2和C3節點的最短路徑,彙總一下:
到達C列時最終也只剩3條備選的最短路徑,我們仍然沒有足夠信息斷定哪條纔是全局最短。最後,我們繼續看E節點,才能得出最後的結論。
(4)S 到 E 的最短路徑
到E的路徑也只有3種可能性:
E點已經是終點了,我們稍微對比一下這三條路徑的總長度就能知道哪條是最短路徑了。
在效率方面相對於粗暴地遍歷所有路徑,viterbi 維特比算法到達每一列的時候都會刪除不符合最短路徑要求的路徑,大大降低時間複雜度。
(以上所有內容轉自知乎《如何通俗地講解 viterbi 算法?》,如有侵權請聯繫我刪除!)
3、python實現
上述問題只涉及節點之間的距離,這裏我們假設每個節點本身有一個狀態,節點與節點之間的距離用權重表示。爲了簡化描述和編程方便,將 S 到 A 列的權重全部置爲1,C 列到 E 的權重也全部置爲1,只考慮A、B、C三列。
用矩陣 state
表示節點的狀態,(d, n)=state.shape
,d 就表示每一層節點的數量,n 表示總層數。
用矩陣 weight
表示相鄰層之間的路徑距離,n 層就有 n-1 個權重矩陣,weight[k][i][j]
表示第 k-1 層的節點 i 到第 k 層的節點 j 之間的距離。
import numpy as np
state = [[0.9, 0.1, 0.3],
[0.1, 0.8, 0.4],
[0.0, 0.1, 0.3]]
weight = [[[0.1, 0.4, 0.5], [0.2, 0.7, 0.1], [0.9, 0.0, 0.1]],
[[0.8, 0.1, 0.1], [0.4, 0.3, 0.3], [0.1, 0.2, 0.7]]]
def viterbi(state, weight):
'''
:param state: 狀態矩陣
:param weight: 權重矩陣
:return:
'''
state = np.array(state)
weight = np.array(weight)
d, n = state.shape
assert weight.shape == (n - 1, d, d), 'state not match path!'
# 路徑矩陣,元素值表示當前節點從前一層的那一個節點過來是最優的
path = np.zeros(shape=(d, n))
for i in range(n):
print(f'進入第 {i} 層')
if i == 0:
path[:, i] = np.array(range(d)) + 1
print('')
continue
for j in range(d):
print(f'更新節點 ({j}, {i}) 的狀態')
temp = state[:, i - 1] * weight[i - 1, :, j]
temp_max = max(temp)
temp_index = np.where(temp == temp_max)
path[j, i] = temp_index[0] + 1
state[j, i] = max(temp) * state[j, i]
print('')
print(state)
print(path)
if __name__ == '__main__':
viterbi(state, weight)