這個標題比較長,其實需求很明確:在一些最簡單的PyTorch的語言模型model中,原項目有時候並沒有提供限制Vocabulary大小的功能,但這個又是大家常見的需求,所以我用最簡單的方式總結一下:
在這裏給出的例子是可以直接運行的:
https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/02-intermediate/language_model
但是我們可以看到,其原有的data_utils.py文件裏並沒有提供限制Vocabulary大小的功能,這裏我們假定需要把Vocabulary限制在5000,下面這段代碼就可以在原基礎上實現:
import torch
import os
class Dictionary(object):
def __init__(self):
self.word2idx = {}
self.idx2word = {}
self.idx = 0
def add_word(self, word):
if not word in self.word2idx:
self.word2idx[word] = self.idx
self.idx2word[self.idx] = word
self.idx += 1
def __len__(self):
return len(self.word2idx)
class Corpus(object):
def __init__(self):
self.dictionary = Dictionary()
def get_data(self, path, batch_size=20, max_vocab_size=5000):
raw_vocab={}
special_words = ['<unk>']
with open(path, 'r') as f:
tokens = 0
for line in f:
words = line.split() + ['<eos>']
tokens += len(words)
for word in words:
if word in raw_vocab:
raw_vocab[word]+=1
else:
raw_vocab[word]=1
if('<unk>' in raw_vocab):
vocab=sorted(raw_vocab, key = lambda x: -raw_vocab.get(x))
else:
vocab=special_words+sorted(raw_vocab, key = lambda x: -raw_vocab.get(x))
print('Original Vocabulary Size is %d'%len(vocab))
if(len(vocab)>max_vocab_size):
vocab = vocab[ : max_vocab_size]
# Add words to the dictionary
with open(path, 'r') as f:
tokens = 0
for line in f:
words = line.split() + ['<eos>']
tokens += len(words)
for word in words:
if(word in vocab):
self.dictionary.add_word(word)
else:
self.dictionary.add_word('<unk>')
print('The Generated Vocabulary Size is %d'%self.dictionary.__len__())
# Tokenize the file content
ids = torch.LongTensor(tokens)
token = 0
with open(path, 'r') as f:
for line in f:
words = line.split() + ['<eos>']
for word in words:
if(word in vocab):
ids[token] = self.dictionary.word2idx[word]
else:
ids[token] = self.dictionary.word2idx['<unk>']
token += 1
num_batches = ids.size(0) // batch_size
ids = ids[:num_batches*batch_size]
return ids.view(batch_size, -1)
這裏唯一需要注意的,就是原語料庫裏可能就有<unk>,在沒有的情況下才需要加上<unk>。其他內容都很簡單,配合上面網址中的代碼和數據,即可進行測試。