pytorch下搭建網絡訓練並保存模型

最近在學習pytorch,使用mnist數據集,搭建AlexNet訓練並保存模型,將代碼做一記錄。

建立數據集的方法見《pytorch建立自己的數據集(以mnist爲例)》。

搭建網絡的方法見《用pytorch搭建AlexNet(微調預訓練模型及手動搭建)》。

訓練代碼如下:

import torch
import os
from torchvision import transforms
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import DataProcessing as DP
import BuildModel as BM
import torch.nn as nn

if __name__ == '__main__':
    # Train AlexNet on an MNIST image dataset, report training metrics every
    # epoch and test metrics every 5 epochs, then save the model's state_dict.

    # Restrict CUDA to GPUs 0 and 1. Set before the first CUDA call below.
    # NOTE(review): the model is not wrapped in nn.DataParallel here, so unless
    # BM.BuildAlexNet does that itself only one of the two devices is used.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

    root_path = '/opt/Data/lixiang/ex./pytorch/Alexnet/data/'
    training_path = 'trainingset/'
    test_path = 'testset/'
    model_path = '/opt/Data/lixiang/ex./pytorch/Alexnet/model/'

    # These paths are relative; root_path is handed separately to the dataset.
    # (presumably DataProcessingMnist joins them with root_path - confirm)
    training_imgfile = training_path + 'trainingset_img.txt'
    training_labelfile = training_path + 'trainingset_label.txt'
    training_imgdata = training_path + 'img/'

    test_imgfile = test_path + 'testset_img.txt'
    test_labelfile = test_path + 'testset_label.txt'
    test_imgdata = test_path + 'img/'

    # Hyper-parameters.
    batch_size = 128
    epochs = 20
    model_type = 'pre'
    nclasses = 10
    lr = 0.01
    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')

    # Fail early (not after the whole training run) if the output directory
    # for the saved weights is missing.
    if not os.path.isdir(model_path):
        os.makedirs(model_path)

    # Resize -> centre-crop to 224x224 -> tensor -> per-channel normalisation.
    # transforms.Resize replaces the deprecated/removed transforms.Scale.
    # NOTE(review): the three-channel mean/std assumes the dataset yields
    # 3-channel images - confirm against DataProcessingMnist.
    transformations = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    dataset_train = DP.DataProcessingMnist(
        root_path, training_imgfile, training_labelfile, training_imgdata, transformations)
    dataset_test = DP.DataProcessingMnist(
        root_path, test_imgfile, test_labelfile, test_imgdata, transformations)

    num_train, num_test = len(dataset_train), len(dataset_test)

    train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)

    # Build the model and move it to the target device exactly once, BEFORE
    # the optimizer is constructed, so the optimizer holds the on-device
    # parameters (previously model.cuda() was re-invoked on every batch).
    model = BM.BuildAlexNet(model_type, nclasses)
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    # CrossEntropyLoss takes raw logits and class-index (int64) targets.
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        # Training mode: enables dropout/batch-norm updates if the model has them.
        model.train()
        epoch_loss = 0.0  # sum of per-batch mean losses over the epoch
        correct_num = 0   # number of correctly classified training samples
        for x_train, y_train in train_loader:
            x_train, y_train = x_train.to(device), y_train.to(device)

            optimizer.zero_grad()
            y_pre = model(x_train)
            loss = criterion(y_pre, y_train)
            loss.backward()
            optimizer.step()

            # Predicted class = index of the largest logit; outputs already
            # live on the model's device, so no extra transfer is needed.
            _, label_pre = torch.max(y_pre.detach(), 1)
            batch_loss = loss.item()  # .item() replaces the broken loss.data[0]
            batch_correct = (label_pre == y_train).sum().item()
            epoch_loss += batch_loss
            correct_num += batch_correct
            print('batch loss: {} batch acc: {}'.format(
                batch_loss, batch_correct / float(len(y_train))))
        print('epoch: {} training loss: {}, training acc: {}'.format(
            epoch, epoch_loss, correct_num / float(num_train)))

        # Evaluate on the test set every 5 epochs.
        if (epoch + 1) % 5 == 0:
            # Eval mode + no_grad: deterministic forward pass, no autograd
            # graph kept, so evaluation is cheaper and not perturbed by dropout.
            model.eval()
            test_loss = 0.0
            test_acc_num = 0
            with torch.no_grad():
                for x_test, y_test in test_loader:
                    x_test, y_test = x_test.to(device), y_test.to(device)
                    y_pre = model(x_test)
                    _, label_pre = torch.max(y_pre, 1)
                    test_loss += criterion(y_pre, y_test).item()
                    test_acc_num += (label_pre == y_test).sum().item()
            print('epoch: {} test loss: {} test acc: {}'.format(
                epoch, test_loss, test_acc_num / float(num_test)))

    # Persist only the learned parameters (state_dict), not the whole module.
    torch.save(model.state_dict(), model_path + 'AlexNet_params.pkl')

需要注意的主要是一些數據類型和設備的問題:label的類型必須是LongTensor;損失函數nn.CrossEntropyLoss()的輸入target必須是類別編號,而不是one-hot編碼;使用gpu時要把model和輸入數據(圖像與label)移動到gpu上,這樣輸出y_pre和label_pre會自動位於gpu上,無需再另外移動。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章