task 4-零基礎CV入門-【模型訓練與驗證】

任務:

在瞭解了賽題背景知識、數據集、數據處理、常用模型以及大佬給的baseline之後,現在學習調參訓練模型並保存,再用訓練好的模型對測試集進行驗證。

目標:

  1. 劃分數據集,用訓練集和驗證集進行訓練
  2. 學會pytorch下的模型保存和加載
  3. 瞭解調參流程

準備數據打包

打包:用pytorch封裝的dataloader打包數據,傳入DATAfunc(i.e.datasets),batch_size等參數。dataloader可以根據batch_size將datasets打包成很多個batch_size大小的batch。舉個例子吧,datasets一共有3w個數據,batch_size=256,那麼len(dataloader)=118。因爲3w/256=117.1875,也就是有118個batch。

import torch
test_loader = torch.utils.data.DataLoader(
       	datasets,
        batch_size=256,
        shuffle=False,
        num_workers=2,
)

劃分數據集和驗證集:

訓練集:用於訓練模型,調整模型參數;
驗證集:用於驗證模型精度,調整超參
測試集:驗證模型泛化能力

  • 本次比賽的數據已經幫助我們劃分好了訓練集和驗證集,我們可以將兩個數據集合並,重新進行劃分訓練(注意相應的label也需要按照此方法劃分一致)。
  • 對於數據量較少時,可以採用K折交叉驗證的方法,也就是將訓練數據集劃分成K份,用K-1份去訓練,用1份做驗證,進行K次訓練,最後得到的驗證分數取平均值即可。K一般取值4/5。

模型保存和加載:

# 建立模型
 model = SVHN_Model1()
 criterion = nn.CrossEntropyLoss() # 評估指標,metrics
 optimizer = torch.optim.Adam(model.parameters(), 0.001) # 模型優化器
 best_loss = 1000.0 # loss限制

# 此處省去訓練過程

# 記錄驗證集精度
if val_loss < best_loss:
	best_loss = val_loss
    # print('Find better model in Epoch {0}, saving model.'.format(epoch))
    torch.save(model.state_dict(), './model.pt')  # pytorch保存模型

# 加載保存的最優模型
model.load_state_dict(torch.load('model.pt', map_location='cpu'))
# 接下來就可以拿比賽的測試集去預測,出成績啦

模型訓練:

def train(train_loader, model, criterion, optimizer):
    # 切換模型爲訓練模式
    model.train()
    train_loss = []
    
    for i, (input, target) in enumerate(tqdm(train_loader)):
        target = target.long()  
        c0, c1, c2, c3, c4 = model(input)
        loss = criterion(c0, target[:, 0]) + \
                criterion(c1, target[:, 1]) + \
                criterion(c2, target[:, 2]) + \
                criterion(c3, target[:, 3]) + \
                criterion(c4, target[:, 4])
        
        # loss /= 6
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss.append(loss.item())
    return np.mean(train_loss)

def validate(val_loader, model, criterion):
    # 切換模型爲預測模型
    model.eval()
    val_loss = []

    # 不記錄模型梯度信息
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            target = target.long() 
            c0, c1, c2, c3, c4 = model(input)
            loss = criterion(c0, target[:, 0]) + \
                    criterion(c1, target[:, 1]) + \
                    criterion(c2, target[:, 2]) + \
                    criterion(c3, target[:, 3]) + \
                    criterion(c4, target[:, 4])
            # loss /= 6
            val_loss.append(loss.item())
    return np.mean(val_loss)

def predict(test_loader, model, tta=10):
    model.eval()
    test_pred_tta = None
    
    # TTA 次數
    for _ in range(tta):
        test_pred = []
    
        with torch.no_grad():
            for i, (input, target) in enumerate(test_loader):
                
                c0, c1, c2, c3, c4 = model(input)
                if use_cuda:
                    output = np.concatenate([
                        c0.data.cpu().numpy(), 
                        c1.data.cpu().numpy(),
                        c2.data.cpu().numpy(), 
                        c3.data.cpu().numpy(),
                        c4.data.cpu().numpy()], axis=1)
                else:
                    output = np.concatenate([
                        c0.data.numpy(), 
                        c1.data.numpy(),
                        c2.data.numpy(), 
                        c3.data.numpy(),
                        c4.data.numpy()], axis=1)
                
                test_pred.append(output)
        
        test_pred = np.vstack(test_pred)
        if test_pred_tta is None:
            test_pred_tta = test_pred
        else:
            test_pred_tta += test_pred
    
    return test_pred_tta

此部分代碼還沒有完全搞明白!!

對baseline進一步詳解:
代碼解釋:

  1. python類中__getitem__方法,可以讓對象實現迭代功能,這樣就可以使用for...in... 來迭代該對象了。
    在用for..in.. 迭代對象時,如果對象沒有實現 __iter__`` __next__迭代器協議,Python的解釋器就會去尋找__getitem__來迭代對象,如果連__getitem__都沒有定義,這解釋器就會報對象不是迭代器的錯誤:TypeError: 'Animal' object is not iterable
    轉自:短短嘟嘟
class People:
    def __init__(self,name):
        self.name = name

    def __getitem__(self, index):
        print('getitem')
        return self.name[index]
 
people = People(['A','B','C'])
for p in people:
    print(p)
getitem
A
getitem
B
getitem
C
getitem
  1. glob.glob(path)
    該函數的path可以使用通配符,比如*就是讀取全部文件,最後返回結果是所有文件的路徑。
import glob
train_path = glob.glob('G:/CVdataset/mchar_train/mchar_train/*.png') # 訓練集的path
train_path
['G:/CVdataset/mchar_train/mchar_train\\000000.png',
 'G:/CVdataset/mchar_train/mchar_train\\000001.png',
 'G:/CVdataset/mchar_train/mchar_train\\000002.png',
 'G:/CVdataset/mchar_train/mchar_train\\000003.png',
 ....]

準備參考的yolo文章:
yolo
yolo
yolo

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章