維特比算法以及解碼時的beamSearch

維特比算法

輸入序列爲詞,輸出序列爲POS,採用HMM爲例介紹維特比算法
HMM假設當前的隱含狀態(POS TAG)只和上一時刻的隱含狀態相關,當前的observation(詞)是由隱含狀態生成的,之和當前的隱含狀態相關。
假設S(k,i,j)是一個集合,集合中每個元素是一個長度爲k的序列,且每個序列的最後兩個元素爲 (i, j).
定義$\pi (t,i,j) $ 代表集合中概率最大的那個序列的概率值,其滿足下面的迭代規則
π(t,i,j)=max(π(t1,i)×q(ji)×e(xtj))\pi (t,i,j) = max(\pi (t-1, i)\times q(j|i)\times e(x_{t}|j))
i 代表t-1時刻的輸出標籤;j代表t時刻的輸出標籤;e表示當前時刻爲標籤j時,obsevation 爲x的概率,是一個後驗概率.

最大化聯合概率相當與
p(x1,x2...xn,y1,y1...yn,STOP)=maxu,v(π(n,u,v)×q(STOPU,V))p\left ( x_{1},x_{2}...x_{n},y_{1},y_{1}...y_{n},STOP \right ) = max_{u,v}\left ( \pi \left ( n,u,v \right )\times q\left ( STOP|U,V \right ) \right )

計算時用一個二維數組來記住每個t時刻取每個j時,相應的t-1時刻應該取的值,以便回溯。

implement

/*
o代表輸入序列,
delta代表上面的pi,
psi:二維數組,第一維用時間做索引,第二維用標籤做索引,
    裏面的值表示當前時刻取第二維索引的值時上一時刻應該取的標籤。代表從後向前的指針索引
q: 代表解碼結果序列,
pprob:代表輸出序列的聯合概率
*/
void Viterbi(HMM *phmm, int T, int *O, double **delta, int **psi, 
	int *q, double *pprob)
{
	int 	i, j;	/* state indices */
	int  	t;	/* time index */	

	int	maxvalind;
	double	maxval, val;

	/* 1. Initialization 初始化t=1時刻的值 
	phmm->N代表詞表的大小; phmm->B[i][O[1]]表示上面公式裏的e
    phmm->pi[i]:表示0時刻的標籤爲i的概率,是一個預設的初始值。0時刻表示句子的起始符,
    句子從1纔是真正的詞
    */
	for (i = 1; i <= phmm->N; i++) {
		delta[1][i] = phmm->pi[i] * (phmm->B[i][O[1]]);
		psi[1][i] = 0;
	}	

	/* 2. Recursion */
	for (t = 2; t <= T; t++) {
	/*
	下面j代表t時刻輸出j, i代表t-1時刻輸出i。
	*/
		for (j = 1; j <= phmm->N; j++) {
		/*
		對t時刻每個j都計算一個pi值,計算時假設t-1時刻的所有pi都是已知的,即delta[t-1][i]已知
		*/
			maxval = 0.0;
			maxvalind = 1;	
			for (i = 1; i <= phmm->N; i++) {
				/*A[i][j]代表轉移概率,即t-1時刻輸出i是,t時刻輸出j的概率,對應 上面公式的q*/
				val = delta[t-1][i]*(phmm->A[i][j]);
				if (val > maxval) {
					maxval = val;	
					maxvalind = i;	
				}
			}
			/*這裏記錄t時刻取j的pi值. B[j][O[t]]對應上面公式的e,即當前時刻被標記爲j時,得到當前輸入的概率*/
			delta[t][j] = maxval*(phmm->B[j][O[t]]);
			/*t時刻要生成j則前面第t-1時刻要輸出maxvalind,psi在這裏其實相當於一個指針指向之前的輸出*/
			psi[t][j] = maxvalind; 

		}
	}

	/* 3. Termination */
	/*全部計算完後,在T時刻計算取delta的最大值時對應的j*/
	*pprob = 0.0;
	q[T] = 1;
	for (i = 1; i <= phmm->N; i++) {
                if (delta[T][i] > *pprob) {
			*pprob = delta[T][i];	
			q[T] = i;
		}
	}

	/* 4. Path (state sequence) backtracking */
	/*psi二維數組裏,第t+1時刻裏的第q[t+1]個元素指明瞭t時刻應該輸出什麼值*/
	for (t = T - 1; t >= 1; t--)
		q[t] = psi[t+1][q[t+1]];

}

beam Search

參考: https://www.zhihu.com/question/54356960

beam search只在test的時候需要。訓練的時候知道正確答案,並不需要再進行這個搜索。

test的時候,假設詞表大小爲3,內容爲a,b,c。beam size是2

decoder解碼的時候:

1: 生成第1個詞的時候,選擇概率最大的2個詞,假設爲a,c,那麼當前序列就是a,c

2:生成第2個詞的時候,我們將當前序列a和c,分別與詞表中的所有詞進行組合,得到新的6個序列aa ab ac ca cb cc,然後從其中選擇2個得分最高的,作爲當前序列,假如爲aa cb

3:後面會不斷重複這個過程,直到遇到結束符爲止。最終輸出2個得分最高的序列。


THU summarise beam search

預測時,先對源句子encoder,得到每個step的輸出和最後一個step的隱藏狀態

    def step_beam(self,
                  session,
                  encoder_inputs,
                  encoder_len,
                  max_len=12,
                  geneos=True):
		'''
		encoder_len: [beam_size, ]
		encoder_inputs : [beam_size, max(encoder_len)],代表輸入的一個句子複製了beam次後的結果,
		所以每次解碼的是一個句子,而不是一個batch的句子
		'''
        beam_size = self.batch_size
		#如果輸入不是batch,則將其變成batch
        if encoder_inputs.shape[0] == 1:
            encoder_inputs = np.repeat(encoder_inputs, beam_size, axis=0)
            encoder_len = np.repeat(encoder_len, beam_size, axis=0)

        if encoder_inputs.shape[1] != max(encoder_len):
            raise ValueError("encoder_inputs and encoder_len does not fit")
        
        #將源句子輸入到encoder,進行編碼
        input_feed = {}
        input_feed[self.encoder_inputs] = encoder_inputs
        input_feed[self.encoder_len] = encoder_len
        #self.att_states :保存encoder時每個step的輸出
        #self.init_state :保存encoder時最後一個step的輸出隱藏狀態
        output_feed = [self.att_states, self.init_state]
        outputs = session.run(output_feed, input_feed)

        att_states = outputs[0]
        #[batch, hiddensize]
        prev_state = outputs[1]

		'''
		開始解碼decoder
		'''
		#解碼第一個step對應的輸入爲ID_GO
        prev_tok = np.ones([beam_size], dtype="int32") * data_util.ID_GO

        input_feed = {}
        input_feed[self.att_states] = att_states
        input_feed[self.encoder_len] = encoder_len
        input_feed[self.generate_len] = 0 #Generate only 1 word

        ret = [[]] * beam_size #用來保存最終的beam_size個解碼序列結果
        neos = np.ones([beam_size], dtype="bool")

        score = np.ones([beam_size], dtype="float32") * (-1e8)
        score[0] = 0
		#上一步的context數組,也就是attentin數組
        attention_prev = np.zeros(
            [self.batch_size, self.state_size], dtype="float32")

        for i in range(max_len):
			#每次計算時,圖中的變量值會全部更新,所以在判斷decoder_fn
			#裏的if cell_output is None時,這個地方爲真
            input_feed[self.init_state] = prev_state #上一個解碼step的hidden state
            input_feed[self.previous_tok] = prev_tok #上一個解碼step的預測詞
            input_feed[self.attention_prev] = attention_prev #上一個解碼step的context
            #final_context_state用來保存這個step會計算得到的context數組
            output_feed = [self.final_context_state,
                           self.outputs_logsoftmax,
                           self.final_state]

            outputs = session.run(output_feed, input_feed)

            attention_prev = outputs[0]
            tok_logsoftmax = np.asarray(outputs[1]) #[beam_size, self.target_vocab_size]
            #beam_size 在函數開頭設置成了batchsize
            tok_logsoftmax = tok_logsoftmax.reshape(
                [beam_size, self.target_vocab_size])
            if not geneos:
                tok_logsoftmax[:, data_util.ID_EOS] = -1e8

            #截取每一行裏取值最大的beam_size個值所對應的索引,按照從
            #小到大的順序排列,這樣tok_argsort的shape爲
            #[beam_size,beam_size],注意這裏每一行取beam_size,而不是整
            #個矩陣裏取beam_size
            tok_argsort = np.argsort(tok_logsoftmax, axis=1)[:, -beam_size:]
            tmp_arg0 = np.arange(beam_size).reshape([beam_size, 1])
            #根據前面得到的索引從tok_logsoftmax拿出具體的分數值,返回的shape爲[beamsize,beamsize]
            tok_argsort_score = tok_logsoftmax[tmp_arg0, tok_argsort]
            
            #neos是一個bool數組,最開始時都是true,當beam裏的某個序列遇到了
            #EOS符號時,代表這個序列解碼結束,這時noes[i]會設置爲false
            tok_argsort_score *= neos.reshape([beam_size, 1])

            #每個分數要和之前的分數相加(本來應該是乘號,但因爲取了Log,
            #所以變成加號)
            tok_argsort_score += score.reshape([beam_size, 1])

            #得到整個矩陣裏取值最大的beam個字符對應的index,從小到大
            all_arg = np.argsort(tok_argsort_score.flatten())[-beam_size:]
            #arg0 對應的是beam也可以說是行號,arg1 代表列號
            arg0 = all_arg // beam_size #previous id in batch
            arg1 = all_arg % beam_size
			#prev_tok是一個數組,代表本次step得到的beam size個解碼符號
            prev_tok = tok_argsort[arg0, arg1] #current word
            prev_state = outputs[2][arg0]
            score = tok_argsort_score[arg0, arg1]
			
			#用來記住這個step裏某個序列是否已經到達了end標誌,爲false時表示一個序列解碼結束
            neos = neos[arg0] & (prev_tok != data_util.ID_EOS)

			#將這次得到的beam個字符連接到之前的beam裏得到新的beam個列表
			#ret代表之前的解碼結果
			#prev_tok代表這個step解碼的beam個結果
            ret_t = []
            for j in range(beam_size):
                ret_t.append(ret[arg0[j]] + [prev_tok[j]])

            ret = ret_t
        #取解碼結果裏對應取值最大的那個序列作爲輸出,
        #列表最後一個序列就是最大的那個序列
        return ret[-1]
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章