基於PyTorch的LSTM語言模型(Language Model)中字典(Vocabulary)大小限制(例如5000以內)的基本方法

這個標題比較長,其實需求很明確:在一些最簡單的PyTorch的語言模型model中,原項目有時候並沒有提供限制Vocabulary大小的功能,但這個又是大家常見的需求,所以我用最簡單的方式總結一下:

在這裏給出的例子是可以直接運行的:

https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/02-intermediate/language_model

但是我們可以看到,其原有的data_utils.py文件裏並沒有提供限制Vocabulary大小的功能,這裏我們假定需要把Vocabulary限制在5000,下面這段代碼就可以在原基礎上實現:

import torch
import os


class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0
    
    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1
    
    def __len__(self):
        return len(self.word2idx)


class Corpus(object):
    def __init__(self):
        self.dictionary = Dictionary()

    def get_data(self, path, batch_size=20, max_vocab_size=5000):
        
        raw_vocab={}
        special_words = ['<unk>']
        
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    if word in raw_vocab:
                        raw_vocab[word]+=1
                    else:
                        raw_vocab[word]=1
        if('<unk>' in raw_vocab):
            vocab=sorted(raw_vocab, key = lambda x: -raw_vocab.get(x))
        else:
            vocab=special_words+sorted(raw_vocab, key = lambda x: -raw_vocab.get(x))
        print('Original Vocabulary Size is %d'%len(vocab))
        if(len(vocab)>max_vocab_size):
            vocab = vocab[ : max_vocab_size]
        
        
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    if(word in vocab): 
                        self.dictionary.add_word(word)  
                    else: 
                        self.dictionary.add_word('<unk>') 
        
        print('The Generated Vocabulary Size is %d'%self.dictionary.__len__())
        
        # Tokenize the file content
        ids = torch.LongTensor(tokens)
        token = 0
        with open(path, 'r') as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    if(word in vocab):
                        ids[token] = self.dictionary.word2idx[word]
                    else:
                        ids[token] = self.dictionary.word2idx['<unk>']
                    token += 1
        
        num_batches = ids.size(0) // batch_size
        ids = ids[:num_batches*batch_size]
        return ids.view(batch_size, -1)

這裏唯一需要注意的,就是原語料庫裏可能就有<unk>,在沒有的情況下才需要加上<unk>。其他內容都很簡單,配合上面網址中的代碼和數據,即可進行測試。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章