Pytorch學習(六) --- 模型訓練的常規train函數flow及其配置

前幾個Pytorch學習博客寫了使用Pytorch的數據讀取、數據增強、數據加載、模型定義,當完成上面幾個步驟,就可以進行模型訓練了。

使用Pytorch進行模型訓練,通常可以將train過程寫成一個函數,簡單的train寫法常規的傳入參數如下:

  • 數據加載器DataLoader
  • 目標模型model
  • 損失函數criterion
  • 優化器optimizer

較爲簡單的train函數可以寫爲如下:

def train(DataLoader, model, criterion, optimizer):
	model.cuda()
	# 指定爲train模式
	model.train()

	for i, (img, target) in tqdm(enumerate(DataLoader)):
		img = img.cuda()
		target = target.cuda()
		# 計算網絡輸出
		output = model(img)
		
		# 計算損失
		loss = criterion(output, target)
		
		
		# 計算梯度和做反向傳播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

那麼,一個較爲完整的使用Pytorch訓練分類任務pipeline可以簡單的表示如下:

1. 定義數據加載
Dataset = torchvision.Dataset(root, transform)

2. 定義模型
model = torchvision.models.xxxx(num_class)

3. 定義數據加載器
DataLoader = torch.util.data.DataLoader(Dataset, batch_size, num_workers)

4. 模型訓練

# 定義優化器
optimizer = 
# 定義損失函數
criterion = 
# 定義學習率調整
scheduler = 
for i in range(epoch_number):
	# 根據epoch調整學習率
	scheduler.step()
	# 調用訓練函數
	train(train_loader, model, criterion, optimizer)

	# 模型保存
	torch.save(model.state_dict(), path)
	

注:以上只是對於使用Pytorch中的API快速做分類任務訓練的一個大框架Pipeline的簡單僞代碼展示,實際編寫code中還有其它的一些util函數,比如計算準確率,訓練到一定階段進行驗證集評估之類。。。
僅供參考!!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章