GLMP Code: Detailed Annotations

GLMP is short for the paper GLOBAL-TO-LOCAL MEMORY POINTER NETWORKS
FOR TASK-ORIENTED DIALOGUE.

Below is a detailed, line-by-line annotation of its code (verified to run).

3.1 Model

3.1.1 ContextRNN

class ContextRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout, n_layers=1):
        # Store the hyperparameters
        super(ContextRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout
        # nn.Dropout: the float argument is the probability of zeroing an element
        self.dropout_layer = nn.Dropout(dropout)
        # nn.Embedding: (vocabulary size, embedding dimension, emit zeros for PAD_token)
        # Note that the embedding dimension equals the hidden dimension here
        self.embedding = nn.Embedding(input_size, hidden_size, padding_idx=PAD_token)
        # PyTorch GRU configured as a bidirectional network
        self.gru = nn.GRU(hidden_size, hidden_size,
                          n_layers, dropout=dropout, bidirectional=True)
        self.W = nn.Linear(2*hidden_size, hidden_size)

    def get_state(self, bsz):
        """Initialize the hidden state (2 directions x batch x hidden_size) for the bidirectional GRU."""
        return _cuda(torch.zeros(2, bsz, self.hidden_size))

    def forward(self, input_seqs, input_lengths, hidden=None):
        # Note: we run this all at once (over multiple batches of multiple sequences)
        # contiguous() returns a tensor laid out contiguously in memory;
        # view() reshapes a tensor without changing its data and requires a contiguous tensor.
        # Together they flatten the (seq_len, batch, memory_size) input so every word is embedded at once
        embedded = self.embedding(input_seqs.contiguous().view(input_seqs.size(0),
                                                               -1).long())
        embedded = embedded.view(input_seqs.size()+(embedded.size(-1),))
        # Sum the word embeddings inside each memory tuple (bag-of-words per position)
        embedded = torch.sum(embedded, 2).squeeze(2)
        embedded = self.dropout_layer(embedded)
        # Initialize the hidden state
        hidden = self.get_state(input_seqs.size(1))
        if input_lengths:
            embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=False)
        outputs, hidden = self.gru(embedded, hidden)
        if input_lengths:
            outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=False)
        # Project the concatenated forward/backward final hidden states back to hidden_size
        hidden = self.W(torch.cat((hidden[0], hidden[1]), dim=1)).unsqueeze(0)
        outputs = self.W(outputs)
        return outputs.transpose(0,1), hidden
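
A minimal usage sketch of ContextRNN (assumptions for illustration: PAD_token and _cuda are already defined as in the GLMP code base, and the toy shapes follow the seq_len-first convention used above):

import torch

rnn = ContextRNN(input_size=100, hidden_size=128, dropout=0.2)
# conv_arr in GLMP has shape (seq_len, batch, memory_size); memory_size = 4 here
dummy_input = torch.randint(1, 100, (10, 2, 4))
lengths = [10, 10]
outputs, hidden = rnn(dummy_input, lengths)
print(outputs.shape)  # (batch, seq_len, hidden_size) after the final transpose
print(hidden.shape)   # (1, batch, hidden_size)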

3.1.2 ExternalKnowledge

class ExternalKnowledge(nn.Module):
    def __init__(self, vocab, embedding_dim, hop, dropout):
        # ExternalKnowledge works like an end-to-end memory network
        super(ExternalKnowledge, self).__init__()
        self.max_hops = hop  # number of hops
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.dropout_layer = nn.Dropout(dropout)
        for hop in range(self.max_hops+1):  # one embedding matrix per hop
            # nn.Embedding: (vocabulary size, embedding dimension, emit zeros for PAD_token)
            C = nn.Embedding(vocab, embedding_dim, padding_idx=PAD_token)
            # Initialize the weights from a normal distribution with mean 0 and std 0.1
            C.weight.data.normal_(0, 0.1)
            # add_module registers a child module on the current model,
            # retrievable later through its name
            self.add_module("C_{}".format(hop), C)
        # AttrProxy lets the hop matrices be indexed as self.C[i], i.e. C_{i}
        self.C = AttrProxy(self, "C_")
        # Softmax over dimension 1 (the memory positions)
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        # Conv1d: (in channels, out channels, kernel size, zero padding added on each side)
        self.conv_layer = nn.Conv1d(embedding_dim, embedding_dim, 5, padding=2)

    def add_lm_embedding(self, full_memory, kb_len, conv_len, hiddens):
        # Add the dialogue-history hidden states to the memory slots that follow the KB entries
        for bi in range(full_memory.size(0)):
            start, end = kb_len[bi], kb_len[bi]+conv_len[bi]
            full_memory[bi, start:end, :] = full_memory[bi, start:end, :] + hiddens[bi, :conv_len[bi], :]
        return full_memory

    def load_memory(self, story, kb_len, conv_len, hidden, dh_outputs):
        # Forward pass of the multiple-hop mechanism that writes the memory
        # squeeze(0) removes the leading dimension of size 1
        u = [hidden.squeeze(0)]
        story_size = story.size()
        self.m_story = []
        for hop in range(self.max_hops):  # K hops of attention
            # Compute c^k_i: embed the words of each knowledge triple and sum them into one vector
            embed_A = self.C[hop](story.contiguous().view(story_size[0], -1).long()) # b * (m * s) * e
            embed_A = embed_A.view(story_size+(embed_A.size(-1),)) # b * m * s * e
            # torch.sum over dimension 2 collapses the words inside each triple
            embed_A = torch.sum(embed_A, 2).squeeze(2) # b * m * e
            if not args["ablationH"]:
                embed_A = self.add_lm_embedding(embed_A, kb_len, conv_len, dh_outputs)
            embed_A = self.dropout_layer(embed_A)
            
            if(len(list(u[-1].size()))==1): 
                u[-1] = u[-1].unsqueeze(0) ## used for bsz = 1.
            # Broadcast the query q^k so it can be multiplied with every memory cell C^k_i
            u_temp = u[-1].unsqueeze(1).expand_as(embed_A)
            # Dot product between the query and each memory cell, one score per position
            prob_logit = torch.sum(embed_A*u_temp, 2)
            # p^k_i as in the paper: softmax over the memory positions
            prob_   = self.softmax(prob_logit)
            
            # Repeat the same embedding steps with the next hop's matrix to obtain c^{k+1}_i
            embed_C = self.C[hop+1](story.contiguous().view(story_size[0], -1).long())
            embed_C = embed_C.view(story_size+(embed_C.size(-1),)) 
            embed_C = torch.sum(embed_C, 2).squeeze(2)
            if not args["ablationH"]:
                embed_C = self.add_lm_embedding(embed_C, kb_len, conv_len, dh_outputs)

            # Broadcast p^k so it can be multiplied with C^{k+1}_i
            prob = prob_.unsqueeze(2).expand_as(embed_C)
            # Weighted sum over the memory gives the readout o^k
            o_k  = torch.sum(embed_C*prob, 1)
            # q^{k+1} = q^k + o^k
            u_k = u[-1] + o_k
            u.append(u_k)
            self.m_story.append(embed_A)
        self.m_story.append(embed_C)
        # Return the global memory pointer G = sigmoid(last-hop logits) and the KB readout q^{K+1}
        return self.sigmoid(prob_logit), u[-1]

    def forward(self, query_vector, global_pointer):
        # u holds the query vector for each hop
        u = [query_vector]
        # Iterate over the K hops to obtain the final read result
        for hop in range(self.max_hops):
            # Retrieve the memory C^k stored by load_memory
            m_A = self.m_story[hop] 
            if not args["ablationG"]:
                # The global pointer re-weights the global contextual representation
                m_A = m_A * global_pointer.unsqueeze(2).expand_as(m_A) 
            if(len(list(u[-1].size()))==1): 
                u[-1] = u[-1].unsqueeze(0) ## used for bsz = 1.
            u_temp = u[-1].unsqueeze(1).expand_as(m_A)
            prob_logits = torch.sum(m_A*u_temp, 2)
            prob_soft   = self.softmax(prob_logits)
            m_C = self.m_story[hop+1] 
            if not args["ablationG"]:
                m_C = m_C * global_pointer.unsqueeze(2).expand_as(m_C)
            prob = prob_soft.unsqueeze(2).expand_as(m_C)
            o_k  = torch.sum(m_C*prob, 1)
            u_k = u[-1] + o_k
            u.append(u_k)
        return prob_soft, prob_logits
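
Both load_memory and forward repeat the same single-hop pattern. The following standalone sketch (the function name one_hop and the toy tensor names are illustrative, not part of the GLMP code) makes the attention step explicit:

import torch
import torch.nn.functional as F

def one_hop(query, embed_A, embed_C):
    """One memory hop: attend over embed_A with the query, read out from embed_C.
    query:   (batch, embed_dim)                the current q^k
    embed_A: (batch, memory_size, embed_dim)   the keys C^k
    embed_C: (batch, memory_size, embed_dim)   the values C^{k+1}
    """
    # p^k_i = softmax_i(q^k . c^k_i): one attention weight per memory position
    prob_logit = torch.sum(embed_A * query.unsqueeze(1), dim=2)   # (batch, memory_size)
    prob = F.softmax(prob_logit, dim=1)
    # o^k = sum_i p^k_i * c^{k+1}_i: weighted readout from the value memory
    o_k = torch.sum(embed_C * prob.unsqueeze(2), dim=1)           # (batch, embed_dim)
    # q^{k+1} = q^k + o^k
    return query + o_k, prob_logit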

3.1.3 LocalMemoryDecoder

class LocalMemoryDecoder(nn.Module):
    def __init__(self, shared_emb, lang, embedding_dim, hop, dropout):
        # Initialize the network
        super(LocalMemoryDecoder, self).__init__()
        self.num_vocab = lang.n_words
        self.lang = lang
        self.max_hops = hop
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.dropout_layer = nn.Dropout(dropout)
        # Store shared_emb as C (not the same C as in ExternalKnowledge);
        # in the GLMP code shared_emb is the encoder's embedding matrix
        self.C = shared_emb
        self.softmax = nn.Softmax(dim=1)
        # The sketch RNN generates a response skeleton with sketch tags in place of slot values
        self.sketch_rnn = nn.GRU(embedding_dim, embedding_dim, dropout=dropout)
        self.relu = nn.ReLU()
        self.projector = nn.Linear(2*embedding_dim, embedding_dim)
        self.conv_layer = nn.Conv1d(embedding_dim, embedding_dim, 5, padding=2)
        self.softmax = nn.Softmax(dim = 1)  # duplicate of the definition above; harmless

    def forward(self, extKnow, story_size, story_lengths, copy_list, encode_hidden, target_batches, max_target_length, batch_size, use_teacher_forcing, get_decoded_words, global_pointer):
        # Initialize variables for vocab and pointer
        # Output tensors for the sketch-word distributions and the local pointer distributions
        all_decoder_outputs_vocab = _cuda(torch.zeros(max_target_length, batch_size, self.num_vocab))
        all_decoder_outputs_ptr = _cuda(torch.zeros(max_target_length, batch_size, story_size[1]))
        decoder_input = _cuda(torch.LongTensor([SOS_token] * batch_size))
        # The mask prevents the same memory position from being copied into two slots
        memory_mask_for_step = _cuda(torch.ones(story_size[0], story_size[1]))
        decoded_fine, decoded_coarse = [], []
        
        hidden = self.relu(self.projector(encode_hidden)).unsqueeze(0)
        
        # Start to generate word-by-word
        for t in range(max_target_length):
            # The hidden state is produced by the next few lines; the only input that
            # changes between iterations is decoder_input
            embed_q = self.dropout_layer(self.C(decoder_input)) # b * e
            if len(embed_q.size()) == 1: embed_q = embed_q.unsqueeze(0)
            _, hidden = self.sketch_rnn(embed_q.unsqueeze(0), hidden)
            # Use the sketch RNN's (single-layer) hidden state as the query vector
            query_vector = hidden[0]
            
            # Compute P_vocab over the sketch vocabulary
            p_vocab = self.attend_vocab(self.C.weight, hidden.squeeze(0))
            all_decoder_outputs_vocab[t] = p_vocab
            # topk returns the top-k values and their indices;
            # only the index of the largest logit is needed here, so applying a softmax
            # first would not change the result -- probably why it is commented out in attend_vocab
            _, topvi = p_vocab.data.topk(1)
            
            # Query the external knowledge using the hidden state of the sketch RNN
            # (extKnow is the ExternalKnowledge module; calling it runs its forward method)
            prob_soft, prob_logits = extKnow(query_vector, global_pointer)
            # L_t: the local memory pointer distribution over memory positions
            all_decoder_outputs_ptr[t] = prob_logits

            if use_teacher_forcing:
                # Teacher forcing: feed the gold sketch word as the next decoder input
                decoder_input = target_batches[:,t]
            else:
                # Otherwise feed the sketch RNN's own previous prediction
                decoder_input = topvi.squeeze()
            
            if get_decoded_words:

                search_len = min(5, min(story_lengths))
                prob_soft = prob_soft * memory_mask_for_step
                # Take the top search_len memory positions as candidate objects for the slot
                _, toppi = prob_soft.data.topk(search_len)
                temp_f, temp_c = [], []
                
                for bi in range(batch_size):
                    token = topvi[bi].item() #topvi[:,0][bi].item()  # index of the predicted sketch word
                    temp_c.append(self.lang.index2word[token])  # the sketch word itself
                    
                    if '@' in self.lang.index2word[token]:  # is it a sketch tag (slot)?
                        # If it is a slot, copy the pointed-to object word into the final output
                        cw = 'UNK'
                        for i in range(search_len):
                            if toppi[:,i][bi] < story_lengths[bi]-1: 
                                cw = copy_list[bi][toppi[:,i][bi].item()]
                                break
                        temp_f.append(cw)
                        
                        if args['record']:
                            # Zero out this memory position so the same value is not copied again
                            memory_mask_for_step[bi, toppi[:,i][bi].item()] = 0
                    else:
                        # Not a slot: output the sketch word directly
                        temp_f.append(self.lang.index2word[token])

                decoded_fine.append(temp_f)
                decoded_coarse.append(temp_c)

        return all_decoder_outputs_vocab, all_decoder_outputs_ptr, decoded_fine, decoded_coarse

    def attend_vocab(self, seq, cond):
        # Transpose the embedding matrix and multiply: (b, e) x (e, v) -> (b, v)
        scores_ = cond.matmul(seq.transpose(1,0))
        # The formula in the paper includes a softmax, but it is commented out here;
        # the raw logits work for both the loss and greedy decoding
        # scores = F.softmax(scores_, dim=1)
        return scores_
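
A small sketch of what attend_vocab computes (toy sizes chosen only for illustration; in the real call, seq is the shared embedding weight matrix self.C.weight):

import torch

vocab_size, embed_dim, batch = 6, 4, 2
emb_weight = torch.randn(vocab_size, embed_dim)     # plays the role of self.C.weight
hidden = torch.randn(batch, embed_dim)              # decoder hidden state at step t
scores = hidden.matmul(emb_weight.transpose(1, 0))  # (batch, vocab_size) raw logits
# Training can use these raw logits with a cross-entropy-style loss, and for greedy
# decoding only the argmax matters, so skipping the softmax changes nothing here.
next_token = scores.topk(1)[1]                      # index of the most likely sketch word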

3.2 GLMP

3.2.1 Encoder&Decoder

def encode_and_decode(self, data, max_target_length, use_teacher_forcing, get_decoded_words):
        # Build unknown mask for memory 
        # Initialize the random masks (word-level dropout applied to the memory during training)
        if args['unk_mask'] and self.decoder.training:
            story_size = data['context_arr'].size()
            rand_mask = np.ones(story_size)
            bi_mask = np.random.binomial([np.ones((story_size[0],story_size[1]))], 1-self.dropout)[0]
            rand_mask[:,:,0] = rand_mask[:,:,0] * bi_mask
            conv_rand_mask = np.ones(data['conv_arr'].size())
            for bi in range(story_size[0]):
                start, end = data['kb_arr_lengths'][bi],  data['kb_arr_lengths'][bi] + data['conv_arr_lengths'][bi]
                conv_rand_mask[:end-start,bi,:] = rand_mask[bi,start:end,:]
            rand_mask = self._cuda(rand_mask)
            conv_rand_mask = self._cuda(conv_rand_mask)
            conv_story = data['conv_arr'] * conv_rand_mask.long()
            story = data['context_arr'] * rand_mask.long()
        else:
            story, conv_story = data['context_arr'], data['conv_arr']
        
        # Encode dialog history and KB to vectors
        # self.encoder is the ContextRNN module; self.extKnow is the ExternalKnowledge module
        dh_outputs, dh_hidden = self.encoder(conv_story, data['conv_arr_lengths'])
        global_pointer, kb_readout = self.extKnow.load_memory(story, data['kb_arr_lengths'], data['conv_arr_lengths'], dh_hidden, dh_outputs)
        # torch.cat concatenates tensors along the given dimension;
        # the dialogue-history encoding and the KB readout are joined into the decoder's initial state
        encoded_hidden = torch.cat((dh_hidden.squeeze(0), kb_readout), dim=1) 
        
        # Get the words that can be copied from the memory
        # (the list of object words in the memory that the decoder may copy)
        batch_size = len(data['context_arr_lengths'])
        self.copy_list = []
        for elm in data['context_arr_plain']:
            elm_temp = [ word_arr[0] for word_arr in elm ]
            self.copy_list.append(elm_temp) 
        
        # The decoder (LocalMemoryDecoder) generates the output sentence
        outputs_vocab, outputs_ptr, decoded_fine, decoded_coarse = self.decoder.forward(
            self.extKnow, 
            story.size(), 
            data['context_arr_lengths'],
            self.copy_list, 
            encoded_hidden, 
            data['sketch_response'], 
            max_target_length, 
            batch_size, 
            use_teacher_forcing, 
            get_decoded_words, 
            global_pointer) 

        return outputs_vocab, outputs_ptr, decoded_fine, decoded_coarse, global_pointer
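
A hedged sketch of how a training step might drive encode_and_decode (the surrounding names -- model, data, and the teacher-forcing ratio -- are assumptions for illustration, not taken from this file):

import random

# data is one batch from the GLMP data loader; model exposes encode_and_decode as above
use_teacher_forcing = random.random() < 0.5          # the 0.5 ratio is an assumption
max_target_length = data['sketch_response'].size(1)  # matches target_batches[:, t] in the decoder
outputs_vocab, outputs_ptr, decoded_fine, decoded_coarse, global_pointer = \
    model.encode_and_decode(data, max_target_length, use_teacher_forcing,
                            get_decoded_words=False)
# outputs_vocab:  (max_target_length, batch, vocab)   -> sketch-word loss
# outputs_ptr:    (max_target_length, batch, memory)  -> local-pointer loss
# global_pointer: (batch, memory)                     -> global-pointer (binary) loss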

3.2.2 Evaluate

def evaluate(self, dev, matric_best, early_stop=None):
        print("STARTING EVALUATION")
        # Set to not-training mode to disable dropout
        # (train() defaults to True, which keeps dropout enabled)
        self.encoder.train(False)
        self.extKnow.train(False)
        self.decoder.train(False)  
        
        ref, hyp = [], []
        acc, total = 0, 0
        dialog_acc_dict = {}
        F1_pred, F1_cal_pred, F1_nav_pred, F1_wet_pred = 0, 0, 0, 0
        F1_count, F1_cal_count, F1_nav_count, F1_wet_count = 0, 0, 0, 0
        pbar = tqdm(enumerate(dev),total=len(dev))
        new_precision, new_recall, new_f1_score = 0, 0, 0

        # Load the global entity list for the KVR dataset
        if args['dataset'] == 'kvr':
            with open('data/KVR/kvret_entities.json') as f:
                global_entity = json.load(f)
                global_entity_list = []
                for key in global_entity.keys():
                    if key != 'poi':
                        global_entity_list += [item.lower().replace(' ', '_') for item in global_entity[key]]
                    else:
                        for item in global_entity['poi']:
                            global_entity_list += [item[k].lower().replace(' ', '_') for k in item.keys()]
                global_entity_list = list(set(global_entity_list))

        for j, data_dev in pbar: 
            # Encode and Decode
            _, _, decoded_fine, decoded_coarse, global_pointer = self.encode_and_decode(data_dev, self.max_resp_len, False, True)
            decoded_coarse = np.transpose(decoded_coarse)
            decoded_fine = np.transpose(decoded_fine)
            for bi, row in enumerate(decoded_fine):  # assemble the predicted sentences and score them
                st = ''
                for e in row:
                    if e == 'EOS': break
                    else: st += e + ' '
                st_c = ''
                for e in decoded_coarse[bi]:
                    if e == 'EOS': break
                    else: st_c += e + ' '
                pred_sent = st.lstrip().rstrip()
                pred_sent_coarse = st_c.lstrip().rstrip()
                gold_sent = data_dev['response_plain'][bi].lstrip().rstrip()
                ref.append(gold_sent)
                hyp.append(pred_sent)
                
                if args['dataset'] == 'kvr': 
                    # compute F1 SCORE
                    # Entity F1 over the full gold entity set, then per domain (calendar, navigation, weather)
                    single_f1, count = self.compute_prf(data_dev['ent_index'][bi], pred_sent.split(), global_entity_list, data_dev['kb_arr_plain'][bi])
                    F1_pred += single_f1
                    F1_count += count
                    single_f1, count = self.compute_prf(data_dev['ent_idx_cal'][bi], pred_sent.split(), global_entity_list, data_dev['kb_arr_plain'][bi])
                    F1_cal_pred += single_f1
                    F1_cal_count += count
                    single_f1, count = self.compute_prf(data_dev['ent_idx_nav'][bi], pred_sent.split(), global_entity_list, data_dev['kb_arr_plain'][bi])
                    F1_nav_pred += single_f1
                    F1_nav_count += count
                    single_f1, count = self.compute_prf(data_dev['ent_idx_wet'][bi], pred_sent.split(), global_entity_list, data_dev['kb_arr_plain'][bi])
                    F1_wet_pred += single_f1
                    F1_wet_count += count
                else:
                    # compute Dialogue Accuracy Score
                    # A dialogue counts as correct only if every one of its responses matches exactly
                    current_id = data_dev['ID'][bi]
                    if current_id not in dialog_acc_dict.keys():
                        dialog_acc_dict[current_id] = []
                    if gold_sent == pred_sent:
                        dialog_acc_dict[current_id].append(1)
                    else:
                        dialog_acc_dict[current_id].append(0)

                # compute Per-response Accuracy Score
                # Exact-match accuracy per response
                total += 1
                if (gold_sent == pred_sent):
                    acc += 1

                if args['genSample']:
                    self.print_examples(bi, data_dev, pred_sent, pred_sent_coarse, gold_sent)

        # Set back to training mode
        # Re-enable dropout for further training
        self.encoder.train(True)
        self.extKnow.train(True)
        self.decoder.train(True)

        bleu_score = moses_multi_bleu(np.array(hyp), np.array(ref), lowercase=True)
        acc_score = acc / float(total)
        print("ACC SCORE:\t"+str(acc_score))

        if args['dataset'] == 'kvr':
            F1_score = F1_pred / float(F1_count)
            print("F1 SCORE:\t{}".format(F1_pred/float(F1_count)))
            print("\tCAL F1:\t{}".format(F1_cal_pred/float(F1_cal_count))) 
            print("\tWET F1:\t{}".format(F1_wet_pred/float(F1_wet_count))) 
            print("\tNAV F1:\t{}".format(F1_nav_pred/float(F1_nav_count))) 
            print("BLEU SCORE:\t"+str(bleu_score))
        else:
            dia_acc = 0
            for k in dialog_acc_dict.keys():
                if len(dialog_acc_dict[k])==sum(dialog_acc_dict[k]):
                    dia_acc += 1
            print("Dialog Accuracy:\t"+str(dia_acc*1.0/len(dialog_acc_dict.keys())))
        
        if (early_stop == 'BLEU'):
            if (bleu_score >= matric_best):
                self.save_model('BLEU-'+str(bleu_score))
                print("MODEL SAVED")
            return bleu_score
        elif (early_stop == 'ENTF1'):
            if (F1_score >= matric_best):
                self.save_model('ENTF1-{:.4f}'.format(F1_score))
                print("MODEL SAVED")  
            return F1_score
        else:
            if (acc_score >= matric_best):
                self.save_model('ACC-{:.4f}'.format(acc_score))
                print("MODEL SAVED")
            return acc_score
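
A hypothetical outer loop showing how evaluate's return value can drive early stopping (the names model, dev_loader, and patience are assumptions for illustration; note that evaluate already saves the model whenever the metric improves):

best_metric, patience, bad_epochs = 0.0, 3, 0
for epoch in range(100):
    # ... one epoch of training ...
    metric = model.evaluate(dev_loader, best_metric, early_stop='BLEU')
    if metric >= best_metric:
        best_metric, bad_epochs = metric, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break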

3.2.3 Compute_prf

def compute_prf(self, gold, pred, global_entity_list, kb_plain):
        local_kb_word = [k[0] for k in kb_plain]
        TP, FP, FN = 0, 0, 0
        if len(gold)!= 0:
            count = 1
            for g in gold:
                if g in pred:
                    TP += 1
                else:
                    FN += 1
            for p in set(pred):
                if p in global_entity_list or p in local_kb_word:
                    if p not in gold:
                        FP += 1
            # Precision: fraction of predicted entities that are correct
            precision = TP / float(TP+FP) if (TP+FP)!=0 else 0
            # Recall: fraction of gold entities that were predicted
            recall = TP / float(TP+FN) if (TP+FN)!=0 else 0
            # F1: the harmonic mean of precision and recall
            F1 = 2 * precision * recall / float(precision + recall) if (precision+recall)!=0 else 0
        else:
            precision, recall, F1, count = 0, 0, 0, 0
        return F1, count
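
A toy example of the entity-F1 computation above (the entities and the variable name model are hypothetical, chosen only to make TP/FP/FN concrete; model is assumed to be a GLMP instance exposing compute_prf):

gold = ['pizza_hut', '5_miles']                        # gold entities for one response
pred = 'the nearest pizza_hut is 3_miles away'.split() # predicted response tokens
global_entity_list = ['pizza_hut', '3_miles', '5_miles']
kb_plain = [['pizza_hut'], ['3_miles']]
# TP = 1 (pizza_hut), FN = 1 (5_miles missed), FP = 1 (3_miles predicted but not gold)
# precision = 1/2, recall = 1/2, F1 = 0.5; count = 1 because the gold set is non-empty
f1, count = model.compute_prf(gold, pred, global_entity_list, kb_plain)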