Cat-vs-Dog Recognition with a Fully Connected Network on a Self-Made Dataset

Network

import torch.nn as nn

class weNet(nn.Module):
    def __init__(self):
        super().__init__()
        # fully connected stack: flattened 100x100x3 image -> 2 class probabilities
        self.layers = nn.Sequential(
            nn.Linear(100*100*3, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
            nn.Softmax(dim=1),   # softmax over the 2 classes
        )

    def forward(self, x):
        output = self.layers(x)
        return output
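
A quick smoke test of the network (a minimal sketch, assuming the class above is saved as weNet.py, matching the import used in the training script below): it pushes a random batch through the net and checks the output shape and that each softmax row sums to 1.

import torch
from weNet import weNet

net = weNet()
fake_batch = torch.randn(4, 100*100*3)   # 4 dummy "images", already flattened
out = net(fake_batch)
print(out.size())       # expected: torch.Size([4, 2])
print(out.sum(dim=1))   # each row sums to 1 because of the softmax layer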

Dataset

import os
import torch
from torch.utils import data
from PIL import Image
import numpy as np

class dataSet(data.Dataset):
    def __init__(self, path):
        super().__init__()
        self.path = path
        self.dataset = []
        self.dataset.extend(os.listdir(path))   # one entry per image file

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # the first character of the filename is the class label (e.g. "0.xx.jpg" / "1.xx.jpg")
        label = torch.Tensor(np.array([int(self.dataset[index][0])]))
        img_path = os.path.join(self.path, self.dataset[index])   # build the image path
        img = Image.open(img_path)                                # open the image
        img_data = torch.Tensor(np.array(img)/255 - 0.5)          # normalise to [-0.5, 0.5]
        return img_data, label

if __name__ == '__main__':
    # visual check: restore one sample to a viewable image
    mydata = dataSet("img")
    x = mydata[6200][0].numpy()                                   # index 0 is the image, index 1 is the label
    img_data = np.array((x + 0.5) * 255, dtype=np.uint8)          # undo normalisation, uint8 for PIL
    img = Image.fromarray(img_data, "RGB")
    img.show()
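
The dataset class reads the class label from the first character of each filename, so the img folder is assumed to hold 100x100 RGB images named with a leading digit, e.g. 0.15.jpg and 1.32.jpg. A minimal preparation sketch under that assumption; the raw/cat and raw/dog source folders and the 0 = cat, 1 = dog mapping are hypothetical:

import os
from PIL import Image

raw_dirs = {"raw/cat": 0, "raw/dog": 1}   # hypothetical source folders and label mapping
os.makedirs("img", exist_ok=True)
for src_dir, label in raw_dirs.items():
    for i, name in enumerate(os.listdir(src_dir)):
        img = Image.open(os.path.join(src_dir, name)).convert("RGB")
        img = img.resize((100, 100))                      # match the 100*100*3 input layer
        img.save(os.path.join("img", "{}.{}.jpg".format(label, i)))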

Training and testing

import torch.nn as nn
from torch.utils import data
import torch
import numpy as np

from weData import dataSet
from weNet import weNet

if __name__ == '__main__':
    my_data = dataSet("img")
    train_data = data.DataLoader(dataset=my_data, batch_size=100, shuffle=True)
    net = weNet()
    optimizer = torch.optim.Adam(net.parameters())
    loss_fun = nn.MSELoss()
    for epoch in range(1):
        for i, (x, y) in enumerate(train_data):
            print("x: {}".format(x.size()))
            print("y: {}".format(y.size()))
            x = x.view(x.size(0), -1)                    # flatten each image to a vector
            print("x after reshape: {}".format(x.size()))
            output = net(x)
            y = y.long()
            # one-hot encode the labels so they match the (batch, 2) network output
            y = torch.zeros(y.size(0), 2).scatter_(1, y.view(-1, 1), 1)
            loss = loss_fun(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                print(loss.item())
            out = torch.argmax(output, dim=1)
            y = torch.argmax(y, dim=1)
            acc = np.mean(np.array(out == y, dtype=np.float32))   # batch accuracy
            print(acc)
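
To use the trained net on a new picture, the same flatten-and-normalise steps apply. A minimal inference sketch, assuming the weights were saved after training with torch.save(net.state_dict(), "wenet.pth") and that test.jpg is a hypothetical 100x100 RGB image:

import numpy as np
import torch
from PIL import Image
from weNet import weNet

net = weNet()
net.load_state_dict(torch.load("wenet.pth"))   # weights saved earlier (assumed)
net.eval()
img = Image.open("test.jpg")                   # hypothetical 100x100 RGB test image
x = torch.Tensor(np.array(img)/255 - 0.5)      # same normalisation as the dataset
x = x.view(1, -1)                              # flatten to (1, 100*100*3)
with torch.no_grad():
    pred = torch.argmax(net(x), dim=1)
print("predicted class:", pred.item())         # 0 or 1, matching the filename labels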

Summary

1. When computing the loss, the dimensions of the labels must match the dimensions of the forward-pass output; the view function can be used to reshape them, as illustrated by the sketch after this list.

2. The dimensions of the samples in the self-made dataset must correspond to the data that the training/test code fetches from it.
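
A small shape walk-through of point 1, using a hypothetical batch of 4 labels as they come out of the DataLoader:

import torch

y = torch.Tensor([[1.], [0.], [1.], [1.]])   # labels from the DataLoader: shape (4, 1)
y = y.long()                                  # scatter_ needs integer indices
one_hot = torch.zeros(y.size(0), 2).scatter_(1, y.view(-1, 1), 1)
print(one_hot.size())                         # torch.Size([4, 2]), same shape as the net output
print(one_hot)                                # one-hot targets ready for MSELoss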

 
