A SIMPLE NEURAL ATTENTIVE META-LEARNER

數據集

Omniglot
  • 包含50個字母表的1623個手寫字符,每個字符包含20個樣本
  • 先調整尺寸到28x28,之後通過多次旋轉90度的方式增加字符的種類,一共6492類
  • 劃分
    • 訓練集:82240項 4112類
    • 驗證集:13760項 688類
    • 測試集:33840項 1692類
Mini-ImageNet
  • 從ImageNet中隨機選取100個類,每類包含600個樣本
  • 將尺寸縮放到84x84
  • 包含
    • 訓練集:64類
    • 驗證集:16類
    • 測試集:20類

數據準備

每個iteration包含多個batch,也就是多個eposide;每個eposide包含隨機的classes_per_it個類別,每個類別包含隨機選擇的sample_per_class個樣本組成support set,query set由這些類中的一個隨機類的一個隨機樣本組成。由於這些樣本是作爲一個序列輸入到模型中的,所以最後一個樣本即爲query set,也就是要預測標籤的樣本。輸入時,將一個batch中的所有eposide的樣本拼接起來一起輸入。

模型

將圖像輸入到時序卷積網絡前,先要對圖像做特徵提取

特徵提取
  • Omniglot:使用和PrototpicalNet相同的結構
  • Mini-ImageNet:在PrototpicalNet中,使用的是和Omniglot相同的結構,通道數減少到32,但是這樣淺層的特徵提取網絡沒有充分的利用SNAIL的容量,所以使用了ResNet進行特徵提取
    • 在這裏插入圖片描述
      [84,84,3][42,42,64][21,21,96][10,10,128][5,5,256][5,5,2048][1,1,2048][1,1,384][84,84,3]\rightarrow[42,42,64]\rightarrow[21,21,96]\rightarrow[10,10,128]\rightarrow[5,5,256]\rightarrow[5,5,2048]\rightarrow[1,1,2048]\rightarrow[1,1,384]
時序卷積
  • 時序卷積是通過在時間維度上膨脹的一維卷積生成時序數據的結構,如下圖所示。這種時序卷積是因果的,所以在下一個時間節點生成的值只會被之前時間節點的信息影響,而不受未來信息的影響。相比較於傳統的RNN,它提供了一種更直接,更高帶寬的方式來獲取過去的信息。但是,爲了處理更長的序列,膨脹率通常是指數級增長的,所以需要的卷積層數和序列長度呈對數關係。因此,只能對很久之前的信息進行粗略的訪問,有限的容量和位置依賴性對於元學習方法是不利的,不能充分利用大量的先前的經驗。
    在這裏插入圖片描述
class CasualConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, 
                 stride=1, dilation=1, groups=1, bias=True):
        super(CasualConv1d, self).__init__()
        self.dilation = dilation
        padding = dilation * (kernel_size - 1)
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, stride,
                                padding, dilation, groups, bias)

    def forward(self, input):
        # Takes something of shape (N, in_channels, T),
        # returns (N, out_channels, T)
        out = self.conv1d(input)
        return out[:, :, :-self.dilation] # 
  • 在這裏插入圖片描述

  • dilation爲膨脹率(如上圖所示,也就是卷積核元素之間的距離),T爲要處理的序列長度,卷積核大小爲2

class DenseBlock(nn.Module):
    def __init__(self, in_channels, dilation, filters, kernel_size=2):
        super(DenseBlock, self).__init__()
        self.casualconv1 = CasualConv1d(in_channels, filters, kernel_size, dilation=dilation)
        self.casualconv2 = CasualConv1d(in_channels, filters, kernel_size, dilation=dilation)

    def forward(self, input):
        # input is dimensions (N, in_channels, T)
        xf = self.casualconv1(input)
        xg = self.casualconv2(input)
        activations = F.tanh(xf) * F.sigmoid(xg) # shape: (N, filters, T)
        return torch.cat((input, activations), dim=1)

在這裏插入圖片描述

  • 爲了提高模型的效果,作者使用了殘差連接和稠密連接。一個denseblock包含一個膨脹率爲R卷積核數爲D的一維因果卷積,使用了geted的激活函數,最後將輸出與輸入進行拼接。
class TCBlock(nn.Module):
    def __init__(self, in_channels, seq_length, filters):
        super(TCBlock, self).__init__()
        self.dense_blocks = nn.ModuleList([DenseBlock(in_channels + i * filters, 2 ** (i+1), filters) for i in range(int(math.ceil(math.log(seq_length))))])

    def forward(self, input):
        # input is dimensions (N, T, in_channels)
        input = torch.transpose(input, 1, 2)
        for block in self.dense_blocks:
            input = block(input)
        return torch.transpose(input, 1, 2)
  • 整個的時序卷積網絡是由一系列的denseblock組成,每個denseblock膨脹率呈指數增加,直到感受野包含整個序列。
注意力模塊

soft attention可以讓模型在可能的無限大的上下文中精確的定位信息,把上下文信息當做無序的鍵值對,通過內容對其進行查找。

class AttentionBlock(nn.Module):
    def __init__(self, in_channels, key_size, value_size):
        super(AttentionBlock, self).__init__()
        self.linear_query = nn.Linear(in_channels, key_size)
        self.linear_keys = nn.Linear(in_channels, key_size)
        self.linear_values = nn.Linear(in_channels, value_size)
        self.sqrt_key_size = math.sqrt(key_size)

    def forward(self, input):
        # input is dim (N, T, in_channels) where N is the batch_size, and T is
        # the sequence length
        mask = np.array([[1 if i>j else 0 for i in range(input.shape[1])] for j in range(input.shape[1])])
        mask = torch.ByteTensor(mask).cuda()

        #import pdb; pdb.set_trace()
        keys = self.linear_keys(input) # shape: (N, T, key_size)
        query = self.linear_query(input) # shape: (N, T, key_size)
        values = self.linear_values(input) # shape: (N, T, value_size)
        temp = torch.bmm(query, torch.transpose(keys, 1, 2)) # shape: (N, T, T)
        temp.data.masked_fill_(mask, -float('inf'))
        temp = F.softmax(temp / self.sqrt_key_size, dim=1) # shape: (N, T, T), broadcasting over any slice [:, x, :], each row of the matrix
        temp = torch.bmm(temp, values) # shape: (N, T, value_size)
        return torch.cat((input, temp), dim=2) # shape: (N, T, in_channels + value_size)

在這裏插入圖片描述

  • 基於self attention,使用鍵值查詢的方式對之前的信息進行訪問,爲了保證在特定的時間節點不能訪問未來的鍵值對,在softmax之前加入了mask,把query與未來的key之間的匹配度設置爲負無窮,最後將輸出與輸入進行拼接。
SNAIL

在這裏插入圖片描述

  • 時序卷積可以在有限的上下文中提供高帶寬的訪問方式,attention可以在很大的上下文中精確地訪問信息,所以將二者結合寄來就得到了SNAIL。在時序卷積產生的上下文中應用causal attention,可以使網絡學習到挑出聚集到的哪些信息,以及如何更好地表示這些信息。SNAIL由兩個卷積和attention交錯組成。
  • 對於N-way,K-shot的問題,輸入序列的長度爲N*K+1
  • 由[192,1,28,28]-encoder->[192,64]-cat->[192,69]->[32,6,69]-AttentionBlock->[32,6,101]-TCBlock->[32,6,357]-AttentionBlock->[32,6,485]-TCBlock->[32,6,741]-AttentionBlock->[32,6,997]-FC->[32,6,5]組成
  • 做完特徵提取後,將標籤與特徵進行拼接後進行輸入,query set的樣本標籤爲全0的vector
  • 標籤採用獨熱碼錶示
  • loss:採用交叉熵損失函數

訓練

過程與PrototpicalNet相同

實驗結果

Model 5-way 1-shot Acc. 5-way 5-shot Acc. 20-way 1-shot Acc. 20-way 5-shot Acc.
Reference Paper 99.07% 99.78% 97.64% 99.36%
This repo 98.31% 99.26% 93.75% 97.88%
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章