AlexNet介紹(2)——遷移學習

本文使用PyTorch提供的AlexNet預訓練模型,在牛津大學提供的Flower(17類花卉)數據集上進行遷移學習。該數據集共1360張圖片,分爲17類,每類80張。爲了方便,我把順序打亂後從中取出1000張作爲訓練集,剩下的360張作爲驗證集。數據集下載:(原文的下載鏈接已缺失)

 

import torch.nn as nn
from torch.utils import data
from torchvision import transforms
from PIL import Image
from torchvision import models as MD
import torch


class Make_data(data.Dataset):
    """Image-classification dataset backed by a plain-text index file.

    Each non-blank line of ``txt`` is expected to look like
    ``<relative image path><two spaces><integer label>``. The image path is
    resolved by prefixing it with ``img``.

    Args:
        txt: Path of the index file listing one sample per line.
        img: Prefix (in this script, a directory ending in "/") prepended
            verbatim to every relative image path.
        tensform: Callable applied to each loaded PIL image, e.g. a
            ``torchvision.transforms.Compose`` pipeline.
    """

    def __init__(self, txt, img, tensform):
        self.image = []  # raw index lines with the trailing newline removed
        self.txt = txt
        self.img = img
        self.tensform = tensform
        # `with` guarantees the index file is closed; the original leaked it.
        with open(self.txt) as index_file:
            for line in index_file:
                entry = line.split("\n")[0]
                # Skip blank lines (e.g. an extra newline at EOF). Otherwise
                # they would become samples that crash in __getitem__.
                if entry.strip():
                    self.image.append(entry)

    def __getitem__(self, item):
        """Return ``(transformed_image, integer_label)`` for sample ``item``."""
        path = self.image[item].split("  ")
        img_path = path[0]
        # Decode eagerly inside `with` so the image file handle is released.
        # Converting to RGB guarantees three channels, which the pretrained
        # AlexNet expects and which default batch collation needs.
        with Image.open(self.img + img_path) as raw:
            img = raw.convert("RGB")
        img = self.tensform(img)
        targe = int(path[1])
        return img, targe

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.image)


def train(dada_loader, epochs=200, lr=0.001, num_classes=17,
          pretrained_path="../models/alexnet-owt.pth",
          save_dir="../models/", save_every=20):
    """Fine-tune a pretrained AlexNet on the batches yielded by ``dada_loader``.

    The last layer of ``model.classifier`` is replaced by a fresh
    ``nn.Linear`` with ``num_classes`` outputs. The whole network is then
    trained with plain SGD and cross-entropy loss.

    Args:
        dada_loader: Iterable of ``(images, labels)`` batches.
        epochs: Number of passes over ``dada_loader``.
        lr: SGD learning rate.
        num_classes: Output size of the new classification layer.
        pretrained_path: State-dict file holding the pretrained AlexNet weights.
        save_dir: Prefix for checkpoint files, written as
            ``save_dir + "<epoch>.pkl"``.
        save_every: Save a checkpoint whenever ``epoch % save_every == 0``.
            The final epoch is always saved as well.

    Returns:
        The trained model.
    """
    # Fall back to CPU so the script still runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = MD.alexnet(pretrained=False)
    model.load_state_dict(torch.load(pretrained_path, map_location="cpu"))
    # Replace the final (ImageNet) classifier layer with a new head sized for
    # this dataset. All earlier layers keep their pretrained weights.
    num_input = model.classifier[6].in_features
    head_layers = list(model.classifier.children())
    head_layers.pop()
    head_layers.append(nn.Linear(num_input, num_classes))
    model.classifier = nn.Sequential(*head_layers)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    opt = torch.optim.SGD(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        for img, targe in dada_loader:
            img = img.to(device)
            targe = targe.to(device)
            output = model(img)
            loss = criterion(output, targe)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # .item() yields a plain float. Summing the loss tensor itself
            # kept every batch's autograd graph alive for the whole epoch.
            epoch_loss += loss.item()
        print(epoch_loss)
        if epoch % save_every == 0 or epoch == epochs - 1:
            torch.save(model.state_dict(), save_dir + str(epoch) + ".pkl")
    return model

def test(dada_loader, model_path="", num_classes=17):
    """Evaluate a fine-tuned AlexNet checkpoint on ``dada_loader``.

    The model is rebuilt with the same replaced classifier head as in
    training, then the checkpoint is loaded into it. The function prints the
    number of correct predictions per batch, then the overall accuracy.

    Args:
        dada_loader: Iterable of ``(images, labels)`` batches.
        model_path: State-dict file written by training (a ``<epoch>.pkl``
            checkpoint). It must be provided.
        num_classes: Output size of the classifier head used in training.

    Returns:
        Fraction of correctly classified samples (0.0 if the loader is empty).

    Raises:
        FileNotFoundError: If ``model_path`` is empty. The original
            ``torch.load("")`` raised the same type.
    """
    if not model_path:
        raise FileNotFoundError(
            "test() needs model_path pointing to a checkpoint saved by train()")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = MD.alexnet(pretrained=False)
    num_input = model.classifier[6].in_features
    head_layers = list(model.classifier.children())
    head_layers.pop()
    head_layers.append(nn.Linear(num_input, num_classes))
    model.classifier = nn.Sequential(*head_layers)
    # Load the trained weights for evaluation.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model = model.to(device)
    # eval() switches off the Dropout layers in torchvision's AlexNet
    # classifier. Without it, predictions are randomly perturbed.
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():  # inference only: skip building autograd graphs
        for img, targe in dada_loader:
            img = img.to(device)
            targe = targe.to(device)
            output = model(img)
            _, pred = torch.max(output, 1)
            batch_correct = (pred == targe).sum().item()
            print(batch_correct)
            correct += batch_correct
            total += targe.size(0)

    accuracy = correct / total if total else 0.0
    print("accuracy: {:.4f}".format(accuracy))
    return accuracy
    
if __name__ == '__main__':
    tensform = transforms.Compose([
        # Resize replaces transforms.Scale, which torchvision deprecated and
        # later removed.
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        # The ImageNet-pretrained weights were trained on inputs normalized
        # with these per-channel statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    traindata = Make_data(txt="../Data/train.txt", img="../Data/flower/", tensform=tensform)
    testdata = Make_data(txt="../Data/test.txt", img="../Data/flower/", tensform=tensform)

    # Reshuffle the training set every epoch so SGD does not see a fixed
    # batch order. Evaluation order does not matter.
    train_loader = torch.utils.data.DataLoader(traindata, batch_size=50, shuffle=True)
    test_loader = torch.utils.data.DataLoader(testdata, batch_size=50)

    train(train_loader)
    # Uncomment to evaluate; test() must be pointed at a checkpoint saved by train().
    #test(test_loader)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章