寫在前面的話
前三節我們進行了數據的預處理,介紹了賽題的相關背景:天池實戰-街景字符編碼識別-task1賽題理解
通過Pytorch 批量讀取圖像數據並進行圖像預處理:天池實戰-街景字符編碼識別-task2數據預處理
最後通過CNN模型完成字符識別功能以及介紹如何在Pytorch 下進行模型的建立:天池實戰-街景字符編碼識別-task3模型建立。
這節主要介紹模型的訓練和驗證。
因爲這節Datawhale提供了很詳細的說明文檔,小一也自認爲沒有人家寫的好,這裏就直接轉載過來吧,貼一下源文檔的路徑:Datawhale 零基礎入門CV賽事-Task4 模型訓練與驗證
爲了方便閱讀,部分內容有改動。
模型訓練與驗證
本節將從構建驗證集、模型訓練和驗證、模型保存與加載和模型調參幾個部分講解,在部分小節中將會結合Pytorch代碼進行講解。
1. 學習目標
- 理解驗證集的作用,並使用訓練集和驗證集完成訓練
- 學會使用Pytorch環境下的模型讀取和加載,並瞭解調參流程
2. 構造驗證集
在機器學習模型(特別是深度學習模型)的訓練過程中,模型是非常容易過擬合的。深度學習模型在不斷的訓練過程中訓練誤差會逐漸降低,但測試誤差的走勢則不一定。
在模型的訓練過程中,模型只能利用訓練數據來進行訓練,模型並不能接觸到測試集上的樣本。
因此模型如果將訓練集學的過好,模型就會記住訓練樣本的細節,導致模型在測試集的泛化效果較差,這種現象稱爲過擬合(Overfitting)。與過擬合相對應的是欠擬合(Underfitting),即模型在訓練集上的擬合效果較差。
如圖所示:隨着模型複雜度和模型訓練輪數的增加,CNN模型在訓練集上的誤差會降低,但在測試集上的誤差會逐漸降低,然後逐漸升高,而我們爲了追求的是模型在測試集上的精度越高越好。
導致模型過擬合的情況有很多種原因,其中最爲常見的情況是模型複雜度(Model Complexity )太高,導致模型學習到了訓練數據的方方面面,學習到了一些細枝末節的規律。
解決上述問題最好的解決方法:構建一個與測試集儘可能分佈一致的樣本集(可稱爲驗證集),在訓練過程中不斷驗證模型在驗證集上的精度,並以此控制模型的訓練。
在一般情況下,我們可以自己在本地劃分出一個驗證集出來,進行本地驗證。訓練集、驗證集和測試集分別有不同的作用:
- 訓練集(Train Set):模型用於訓練和調整模型參數;
- 驗證集(Validation Set):用來驗證模型精度和調整模型超參數;
- 測試集(Test Set):驗證模型的泛化能力。
因爲訓練集和驗證集是分開的,所以模型在驗證集上面的精度在一定程度上可以反映模型的泛化能力。在劃分驗證集的時候,需要注意驗證集的分佈應該與測試集儘量保持一致,不然模型在驗證集上的精度就失去了指導意義。
既然驗證集這麼重要,那麼如何劃分本地驗證集呢?
在一些比賽中,賽題方會給定驗證集;如果賽題方沒有給定驗證集,那麼參賽選手就需要從訓練集中拆分一部分得到驗證集。
驗證集的劃分有如下幾種方式:
-
留出法(Hold-Out)
直接將訓練集劃分成兩部分,新的訓練集和驗證集。這種劃分方式的優點是最爲直接簡單;缺點是隻得到了一份驗證集,有可能導致模型在驗證集上過擬合。留出法應用場景是數據量比較大的情況。
-
交叉驗證法(Cross Validation,CV)
將訓練集劃分成K份,將其中的K-1份作爲訓練集,剩餘的1份作爲驗證集,循環K訓練。這種劃分方式是所有的訓練集都是驗證集,最終模型驗證精度是K份平均得到。它的優點是驗證集精度比較可靠,訓練K次可以得到K個有多樣性差異的模型;CV驗證的缺點是需要訓練K次,不適合數據量很大的情況。
-
自助採樣法(BootStrap)
通過有放回的採樣方式得到新的訓練集和驗證集,每次的訓練集和驗證集都是有區別的。這種劃分方式一般適用於數據量較小的情況。
當然這些劃分方法是從數據劃分方式的角度來講的,在現有的數據比賽中一般採用的劃分方法是留出法和交叉驗證法。如果數據量比較大,留出法還是比較合適的。
當然任何的驗證集的劃分得到的驗證集都是要保證訓練集-驗證集-測試集的分佈是一致的,所以如果不管劃分何種的劃分方式都是需要注意的。
這裏的分佈一般指的是與標籤相關的統計分佈,比如在分類任務中“分佈”指的是標籤的類別分佈,訓練集-驗證集-測試集的類別分佈情況應該大體一致;如果標籤是帶有時序信息,則驗證集和測試集的時間間隔應該保持一致。
3. 模型訓練與驗證
在本節我們目標使用Pytorch來完成CNN的訓練和驗證過程,CNN網絡結構與之前的章節中保持一致。我們需要完成的邏輯結構如下:
- 構造訓練集和驗證集;
- 每輪進行訓練和驗證,並根據最優驗證集精度保存模型。
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=10,
shuffle=True,
num_workers=10,
)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=10,
shuffle=False,
num_workers=10,
)
model = SVHN_Model1()
criterion = nn.CrossEntropyLoss (size_average=False)
optimizer = torch.optim.Adam(model.parameters(), 0.001)
best_loss = 1000.0
for epoch in range(20):
print('Epoch: ', epoch)
train(train_loader, model, criterion, optimizer, epoch)
val_loss = validate(val_loader, model, criterion)
# 記錄下驗證集精度
if val_loss < best_loss:
best_loss = val_loss
torch.save(model.state_dict(), './model.pt')
其中每個Epoch的訓練代碼如下:
def train(train_loader, model, criterion, optimizer, epoch):
# 切換模型爲訓練模式
model.train()
for i, (input, target) in enumerate(train_loader):
c0, c1, c2, c3, c4, c5 = model(data[0])
loss = criterion(c0, data[1][:, 0]) + \
criterion(c1, data[1][:, 1]) + \
criterion(c2, data[1][:, 2]) + \
criterion(c3, data[1][:, 3]) + \
criterion(c4, data[1][:, 4]) + \
criterion(c5, data[1][:, 5])
loss /= 6
optimizer.zero_grad()
loss.backward()
optimizer.step()
其中每個Epoch的驗證代碼如下:
def validate(val_loader, model, criterion):
# 切換模型爲預測模型
model.eval()
val_loss = []
# 不記錄模型梯度信息
with torch.no_grad():
for i, (input, target) in enumerate(val_loader):
c0, c1, c2, c3, c4, c5 = model(data[0])
loss = criterion(c0, data[1][:, 0]) + \
criterion(c1, data[1][:, 1]) + \
criterion(c2, data[1][:, 2]) + \
criterion(c3, data[1][:, 3]) + \
criterion(c4, data[1][:, 4]) + \
criterion(c5, data[1][:, 5])
loss /= 6
val_loss.append(loss.item())
return np.mean(val_loss)
4. 模型的保存與加載
序列化(保存)與反序列化(讀取)
模型可以保存整個module,或者只保存整個module的參數
- 保存整個module:torch.save(net, path)
- 只保存模型參數:torch.save(net.state_dict(), path)
相應的讀取模型:
- 讀取整個模型:torch.load(fpath)
- 讀取模型參數:net.load_state_dict(torch.load(path))。使用的時候需要先創建一個網絡,然後加載模型參數即可
5. 設置模型斷點
設置模型斷點checkpoint防止在模型訓練過程中意外中斷,在斷點中需要保存模型的數據,優化器的數據和迭代次數等
在每次開始訓練模型的時候都應該判斷是否存在斷點,方便進行模型恢復
6. 模型的微調(finetune)
模型微調的步驟:
- 獲取預訓練模型參數(源任務當中學習到的知識)
- 加載模型(load_state_dict)將學習到的知識放到新的模型
- 修改輸出層, 以適應新的任務
模型微調的訓練方法:
- 固定預訓練的參數(requires_grad=False; lr=0)
- Features Extractor較小學習率(params_group)
總結
本節以深度學習模型的訓練和驗證爲基礎,講解了驗證集劃分方法、模型訓練與驗證、模型保存和加載以及模型調參流程。
需要注意的是模型複雜度是相對的,並不一定模型越複雜越好。在有限設備和有限時間下,需要選擇能夠快速迭代訓練的模型。