MAML復現全部細節和經驗教訓(Pytorch)

由於MAML作者提供的源碼比較混亂,而且是由tensorflow寫成。所以我寫了一篇用Pytorch復現MAML的博客:MAML模型無關的元學習代碼完整復現(Pytorch版)。那篇博客中的復現細節已經很詳盡了,但是在omniglot數據集上的準確率只有0.92,考慮到omniglot算是比較簡單的數據集了,因此0.92的準確率實在是太低了。

因此,我後來又對模型和數據的讀取方法進行了一些調整,最近的實驗表明在5-way-1-shot任務上,我的復現準確率已經達到了0.972,算是基本匹配上了作者在論文中給出的準確率區間。

在這篇文章中,我將總結一下我復現MAML時的一些經驗和教訓以及對原來代碼的更改。

1 數據讀取方式

我之前的數據讀取方式是將omniglot中images_backgroud和images_evaluation這兩個文件夾中的數據一次性讀取出來,然後再對數據集進行劃分。

img_list = np.load(os.path.join(root_dir, 'omniglot.npy')) # (1623, 20, 1, 28, 28)
x_train = img_list[:1200]
x_test = img_list[1200:]

這一次我使用通用的數據劃分方法,即:images_backgroud中的數據作爲訓練數據,images_evaluation中的數據作爲測試數據。

img_list_train = np.load(os.path.join(root_dir, 'omniglot_train.npy')) # (964, 20, 1, 28, 28)
img_list_test = np.load(os.path.join(root_dir, 'omniglot_test.npy')) # (659, 20, 1, 28, 28)

x_train = img_list_train
x_test = img_list_test

具體代碼見我的github

2 模型構造

原來的模型卷積層的padding爲2,stride也爲2;我將它們修改爲1之後,實驗結果直接從0.92提升到了0.975。由此可見模型架構的微小調整也會嚴重影響模型的性能。大家平時在做實驗時應該注意一下。

原來的模型架構爲:

#         self.conv = nn.Sequential(
#             nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = (3,3), padding = 2, stride = 2),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
            
#             nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = (3,3), padding = 2, stride = 2),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
            
#             nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = (3,3), padding = 2, stride = 2),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2), 
            
#             nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = (3,3), padding = 2, stride = 2),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2), 
            
#             FlattenLayer(),
#             nn.Linear(64,5)
#         )   

修改後的模型架構爲:

#         self.conv = nn.Sequential(
#             nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = (3,3), padding = 1, stride = 1),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
            
#             nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = (3,3), padding = 1, stride = 1),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
            
#             nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = (3,3), padding = 1, stride = 1),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2), 
            
#             nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = (3,3), padding = 1, stride = 1),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2), 
            
#             FlattenLayer(),
#             nn.Linear(64,5)
#         )   

3 降低對計算資源的要求

在進行20-way-1-shot的實驗時,發現用原來的代碼將會消耗大量的資源。我修改了一下原來的代碼,在不需要記錄梯度的位置加上"with torch.no_grad()",從而將計算資源的需求降到了原來的1/5.

原來的代碼爲:

            for k in range(1, self.update_step):
                
                y_hat = self.net(x_spt[i], params = fast_weights, bn_training=True)
                loss = F.cross_entropy(y_hat, y_spt[i])
                grad = torch.autograd.grad(loss, fast_weights)
                tuples = zip(grad, fast_weights) 
                fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))
                    
                y_hat = self.net(x_qry[i], params = fast_weights, bn_training = True)
                loss_qry = F.cross_entropy(y_hat, y_qry[i])
                loss_list_qry[k+1] += loss_qry
                
                with torch.no_grad():
                    pred_qry = F.softmax(y_hat,dim=1).argmax(dim=1)
                    correct = torch.eq(pred_qry, y_qry[i]).sum().item()
                    correct_list[k+1] += correct

修改後的代碼爲:

            for k in range(1, self.update_step):
                
                y_hat = self.net(x_spt[i], params = fast_weights, bn_training=True)
                loss = F.cross_entropy(y_hat, y_spt[i])
                grad = torch.autograd.grad(loss, fast_weights)
                tuples = zip(grad, fast_weights) 
                fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))
                
                if k < self.update_step - 1:
                    with torch.no_grad():        
                        y_hat = self.net(x_qry[i], params = fast_weights, bn_training = True)
                        loss_qry = F.cross_entropy(y_hat, y_qry[i])
                        loss_list_qry[k+1] += loss_qry
                else:
                        y_hat = self.net(x_qry[i], params = fast_weights, bn_training = True)
                        loss_qry = F.cross_entropy(y_hat, y_qry[i])
                        loss_list_qry[k+1] += loss_qry                    
                
                with torch.no_grad():
                    pred_qry = F.softmax(y_hat,dim=1).argmax(dim=1)
                    correct = torch.eq(pred_qry, y_qry[i]).sum().item()
                    correct_list[k+1] += correct

4 關於20-way-1-shot實驗

我在復現這個實驗的過程中,在query集的測試集中的最好結果也只有0.843。但是作者宣稱她取得了0.95的實驗結果,但是作者的源碼中並沒有給出20-way-1-shot的實驗結果或者logs。

我找到了另一個網友(github賬號名:katerkelly)的復現代碼,這個人宣稱他復現出來的結果是0.92。

20-way 1-shot training, best performance 92%

但是我實際運行以及查看了他的代碼後發現,他報告的其實是support集中測試集的結果,而不是query集中測試集的結果。我們都知道在元學習中有support集和query集兩者集合,其中:

  • support集:分爲訓練集和測試集,其中訓練集用於訓練,測試集用於更新參數。
  • query集:分爲訓練集和測試集,其中訓練集用於fine-tune,測試集用於評估元學習模型的效果。

而那位網友報告的是support集中測試集的結果,真正的實驗結果應該是query集中測試集的實驗結果,也就是0.83。

你可以查看那位網友給出的實驗結果展示圖(下圖)。中間那條橙黃色的線是0.92左右,那位網友報告的也是橙黃色這條線的結果,但是實際的實驗結果應該是下面這條紅色的線。也就是0.83左右,跟我得出的實驗結果比較吻合。
網友的實驗結果
有意思的是,MAML作者聲稱她的實驗結果實0.95,而我自己復現的結果中,在support集的測試集上的結果也是0.95-0.96。爲了跑出0.9以上的實驗結果,我已經做了好幾天的實驗了,模型架構和超參數改動了幾十次,最好的結果還是隻有0.843。如果哪位網友能夠復現出0.9以上的實驗結果,麻煩告訴我一下。

5 實驗數據

以下展示在60000輪epoch中,query集的測試集中出現的最好結果:

  1. 20 way 1 shot 4 batch meta_lr = 0.0002, base_lr = 0.1 : acc: 0.84

  2. 20 way 1 shot 8 batch meta_lr = 0.0001, base_lr = 0.1 : acc: 0.835

  3. 20 way 1 shot 8 batch meta_lr = 0.0001, base_lr = 0.1 : acc: 0.843

  4. 20 way 1 shot 8 batch meta_lr = 0.0005, base_lr = 0.3 : acc: 0.79

  5. 20 way 1 shot 8 batch meta_lr = 0.001, base_lr = 0.1 : acc: 0.82

  6. 20 way 1 shot 8 batch meta_lr = 0.001, base_lr = 0.2 : acc: 0.785

  7. 5 way 1 shot 4 batch 10 range meta_lr = 0.001, base_lr = 0.1 : acc: 0.96

  8. 5 way 1 shot 8 batch 10 range meta_lr = 0.001, base_lr = 0.1 : acc: 0.972

  9. 5 way 1 shot 16 batch 10 range meta_lr = 0.001, base_lr = 0.1 : acc: 0.969

  10. 5 way 1 shot 32 batch 10 range meta_lr = 0.001, base_lr = 0.1 : acc: 0.975

自己想要復現的朋友,可以參考一下我的實驗結果,免得繼續做無用功。

6 關於我自己的源碼

你可以在我的github上找到我的全部代碼(miguealanmath)。喜歡的朋友可以點下小星星。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章