數據集
Omniglot
- 包含50個字母表的1623個手寫字符,每個字符包含20個樣本
- 先調整尺寸到28x28,之後通過多次旋轉90度的方式增加字符的種類,一共6492類
- 劃分
- 訓練集:82240項 4112類
- 驗證集:13760項 688類
- 測試集:33840項 1692類
Mini-ImageNet
- 從ImageNet中隨機選取100個類,每類包含600個樣本
- 將尺寸縮放到84x84
- 包含
- 訓練集:64類
- 驗證集:16類
- 測試集:20類
數據準備
每個iteration包含多個batch,也就是多個eposide;每個eposide包含隨機的classes_per_it個類別,每個類別包含隨機選擇的sample_per_class個樣本組成support set,query set由這些類中的一個隨機類的一個隨機樣本組成。由於這些樣本是作爲一個序列輸入到模型中的,所以最後一個樣本即爲query set,也就是要預測標籤的樣本。輸入時,將一個batch中的所有eposide的樣本拼接起來一起輸入。
模型
將圖像輸入到時序卷積網絡前,先要對圖像做特徵提取
特徵提取
- Omniglot:使用和PrototpicalNet相同的結構
- Mini-ImageNet:在PrototpicalNet中,使用的是和Omniglot相同的結構,通道數減少到32,但是這樣淺層的特徵提取網絡沒有充分的利用SNAIL的容量,所以使用了ResNet進行特徵提取
時序卷積
- 時序卷積是通過在時間維度上膨脹的一維卷積生成時序數據的結構,如下圖所示。這種時序卷積是因果的,所以在下一個時間節點生成的值只會被之前時間節點的信息影響,而不受未來信息的影響。相比較於傳統的RNN,它提供了一種更直接,更高帶寬的方式來獲取過去的信息。但是,爲了處理更長的序列,膨脹率通常是指數級增長的,所以需要的卷積層數和序列長度呈對數關係。因此,只能對很久之前的信息進行粗略的訪問,有限的容量和位置依賴性對於元學習方法是不利的,不能充分利用大量的先前的經驗。
class CasualConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, dilation=1, groups=1, bias=True):
super(CasualConv1d, self).__init__()
self.dilation = dilation
padding = dilation * (kernel_size - 1)
self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
def forward(self, input):
# Takes something of shape (N, in_channels, T),
# returns (N, out_channels, T)
out = self.conv1d(input)
return out[:, :, :-self.dilation] #
-
dilation爲膨脹率(如上圖所示,也就是卷積核元素之間的距離),T爲要處理的序列長度,卷積核大小爲2
class DenseBlock(nn.Module):
def __init__(self, in_channels, dilation, filters, kernel_size=2):
super(DenseBlock, self).__init__()
self.casualconv1 = CasualConv1d(in_channels, filters, kernel_size, dilation=dilation)
self.casualconv2 = CasualConv1d(in_channels, filters, kernel_size, dilation=dilation)
def forward(self, input):
# input is dimensions (N, in_channels, T)
xf = self.casualconv1(input)
xg = self.casualconv2(input)
activations = F.tanh(xf) * F.sigmoid(xg) # shape: (N, filters, T)
return torch.cat((input, activations), dim=1)
- 爲了提高模型的效果,作者使用了殘差連接和稠密連接。一個denseblock包含一個膨脹率爲R卷積核數爲D的一維因果卷積,使用了geted的激活函數,最後將輸出與輸入進行拼接。
class TCBlock(nn.Module):
def __init__(self, in_channels, seq_length, filters):
super(TCBlock, self).__init__()
self.dense_blocks = nn.ModuleList([DenseBlock(in_channels + i * filters, 2 ** (i+1), filters) for i in range(int(math.ceil(math.log(seq_length))))])
def forward(self, input):
# input is dimensions (N, T, in_channels)
input = torch.transpose(input, 1, 2)
for block in self.dense_blocks:
input = block(input)
return torch.transpose(input, 1, 2)
- 整個的時序卷積網絡是由一系列的denseblock組成,每個denseblock膨脹率呈指數增加,直到感受野包含整個序列。
注意力模塊
soft attention可以讓模型在可能的無限大的上下文中精確的定位信息,把上下文信息當做無序的鍵值對,通過內容對其進行查找。
class AttentionBlock(nn.Module):
def __init__(self, in_channels, key_size, value_size):
super(AttentionBlock, self).__init__()
self.linear_query = nn.Linear(in_channels, key_size)
self.linear_keys = nn.Linear(in_channels, key_size)
self.linear_values = nn.Linear(in_channels, value_size)
self.sqrt_key_size = math.sqrt(key_size)
def forward(self, input):
# input is dim (N, T, in_channels) where N is the batch_size, and T is
# the sequence length
mask = np.array([[1 if i>j else 0 for i in range(input.shape[1])] for j in range(input.shape[1])])
mask = torch.ByteTensor(mask).cuda()
#import pdb; pdb.set_trace()
keys = self.linear_keys(input) # shape: (N, T, key_size)
query = self.linear_query(input) # shape: (N, T, key_size)
values = self.linear_values(input) # shape: (N, T, value_size)
temp = torch.bmm(query, torch.transpose(keys, 1, 2)) # shape: (N, T, T)
temp.data.masked_fill_(mask, -float('inf'))
temp = F.softmax(temp / self.sqrt_key_size, dim=1) # shape: (N, T, T), broadcasting over any slice [:, x, :], each row of the matrix
temp = torch.bmm(temp, values) # shape: (N, T, value_size)
return torch.cat((input, temp), dim=2) # shape: (N, T, in_channels + value_size)
- 基於self attention,使用鍵值查詢的方式對之前的信息進行訪問,爲了保證在特定的時間節點不能訪問未來的鍵值對,在softmax之前加入了mask,把query與未來的key之間的匹配度設置爲負無窮,最後將輸出與輸入進行拼接。
SNAIL
- 時序卷積可以在有限的上下文中提供高帶寬的訪問方式,attention可以在很大的上下文中精確地訪問信息,所以將二者結合寄來就得到了SNAIL。在時序卷積產生的上下文中應用causal attention,可以使網絡學習到挑出聚集到的哪些信息,以及如何更好地表示這些信息。SNAIL由兩個卷積和attention交錯組成。
- 對於N-way,K-shot的問題,輸入序列的長度爲N*K+1
- 由[192,1,28,28]-encoder->[192,64]-cat->[192,69]->[32,6,69]-AttentionBlock->[32,6,101]-TCBlock->[32,6,357]-AttentionBlock->[32,6,485]-TCBlock->[32,6,741]-AttentionBlock->[32,6,997]-FC->[32,6,5]組成
- 做完特徵提取後,將標籤與特徵進行拼接後進行輸入,query set的樣本標籤爲全0的vector
- 標籤採用獨熱碼錶示
- loss:採用交叉熵損失函數
訓練
過程與PrototpicalNet相同
實驗結果
Model | 5-way 1-shot Acc. | 5-way 5-shot Acc. | 20-way 1-shot Acc. | 20-way 5-shot Acc. |
---|---|---|---|---|
Reference Paper | 99.07% | 99.78% | 97.64% | 99.36% |
This repo | 98.31% | 99.26% | 93.75% | 97.88% |