PyTorch:Digit Recognizer比賽後續-數據增強

背景

上次利用PyTorch做了一個模型以後,訓練的結果只有94%左右,並不是特別好。本次進行了一些改進:

  1. 增加訓練次數
  2. 添加數據增強

代碼鏈接:
https://www.kaggle.com/yannnnnnnnnnnn/kernel5d66c76231?scriptVersionId=28235847 version9
結果:
在這裏插入圖片描述


方法

1、增加訓練次數

這個沒啥好說的,直接修改epoch=30即可。

2、增加數據增強

增加數據增強,首先要繼承Dataset數據結構,構建MNIST_Data,代碼如下:
其中已經定義了一個放射變化的數據增強方法,包括旋轉和平移。
旋轉爲繞中心-30°~+30°,平移爲-0.1*width~+0.1*weidth-0.1*height~+0.1*height

class MNIST_data(Dataset):
    """MNIST dtaa set"""
    
    def __init__(self, 
                 data, 
                 transform = transforms.Compose([transforms.ToPILImage(),
                                                 transforms.RandomAffine(30,(0.1,0.1)),
                                                 transforms.ToTensor()
                                                ])
                ):
        
        if len(data) == 1:
            # test data
            self.X = data[0].reshape(-1,28,28)
            self.y = None
        else:
            # training data
            self.X = data[0].reshape(-1,28,28)
            self.y = data[1].astype(np.long)
            
        self.transform = transform
    
    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        if self.y is not None:
            return self.transform(self.X[idx]), self.y[idx]
        else:
            return self.transform(self.X[idx])

結論

在原始模型的基礎上,僅僅添加了數據增強就大大提高了預測結果(當然第一版本的代碼還有一個bug,忘記寫model.train(),沒有激活dropout)。後續考慮加上:

  • BatchNorm
  • 變化的學習率
發佈了151 篇原創文章 · 獲贊 160 · 訪問量 44萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章