用 PyTorch 寫的一個 LeNet 分類網絡,不是百分之百還原,但結構是一樣的,用來簡單訓練一下自己的數據集。數據集格式如下:data 文件夾內存放自己的數據,每個類別放到一個子文件夾中,子文件夾的名稱即爲類別標籤,如下圖所示。
1、網絡搭建(保存爲 Lenet.py)
import torch
import torch.nn as nn
class Lenet(nn.Module):
    """LeNet-style convolutional classifier for 3-channel 32x32 images.

    Layout: conv(5x5) -> ReLU -> max-pool(2) -> conv(5x5) -> ReLU ->
    max-pool(2) -> conv(5x5) -> ReLU -> flatten -> fc(120->84) -> ReLU ->
    fc(84->num_classes).

    The convolutions use no padding, so a 32x32 input shrinks
    32 -> 28 -> 14 -> 10 -> 5 -> 1; ``fc6`` expects exactly 120 features,
    which means the input must be 32x32.

    Args:
        num_classes: Number of output logits produced by the last layer.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5, 1, 0)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(6, 16, 5, 1, 0)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = nn.Conv2d(16, 120, 5, 1, 0)
        self.fc6 = nn.Linear(120, 84)
        # The output width follows ``num_classes``; it used to be fixed at
        # 10, which silently ignored the constructor argument.
        self.fc7 = nn.Linear(84, num_classes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``."""
        x = self.pool2(self.relu(self.conv1(x)))
        x = self.pool4(self.relu(self.conv3(x)))
        x = self.relu(self.conv5(x))
        # Collapse (batch, 120, 1, 1) into (batch, 120) for the dense layers.
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc6(x))
        return self.fc7(x)
2、讀取自己的圖像及標籤數據(保存爲 get_data.py)
from torch.utils.data import Dataset
import cv2
import os
from torchvision import transforms as tvtsf
import torch
import numpy as np
class cnndata(Dataset):
    """Image-folder dataset: one sub-directory per class under ``path``.

    The sub-directory names are the class labels. Class indices are
    assigned in sorted name order, so the index -> label mapping is the
    same on every run and on every platform.

    Args:
        path: Root directory that contains one folder per class.

    Raises:
        Exception: If ``path`` does not exist.
    """

    def __init__(self, path):
        self.path = path
        self.img, self.cls, self.class_label = self.get_filname_and_cls()
        print("圖像共有:%s 張" % (len(self.img)))

    def __getitem__(self, item):
        """Load sample ``item``.

        Returns:
            A tuple ``(image, label)``. ``image`` is a float64 tensor of
            shape ``(3, 32, 32)`` scaled to ``[0, 1]`` in OpenCV's BGR
            channel order; ``label`` is a 0-dim integer tensor holding the
            class index.

        Raises:
            OSError: If the file cannot be decoded as an image.
        """
        # IMREAD_COLOR always yields a 3-channel BGR array, so grayscale
        # files need no extra conversion afterwards.
        img = cv2.imread(self.img[item], cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError("cannot read image: %s" % self.img[item])
        # The interpolation flag must be passed by keyword: the third
        # positional parameter of cv2.resize is ``dst``.
        img_src = cv2.resize(img, (32, 32), interpolation=cv2.INTER_AREA)
        image = torch.from_numpy(img_src / 255.)
        image = image.permute(2, 0, 1).contiguous()  # HWC -> CHW
        label = torch.from_numpy(np.array(self.cls[item]))
        return image, label

    def get_filname_and_cls(self):
        """Scan ``self.path`` and collect file paths and class indices.

        Returns:
            A tuple ``(imgs, cls, class_label)``: the list of file paths,
            the parallel list of class indices, and a dict mapping each
            class index to its folder name.

        Raises:
            Exception: If ``self.path`` does not exist.
        """
        if not os.path.exists(self.path):
            raise Exception("no wenjianjia")
        # Keep directories only (a stray file in the root is not a class)
        # and sort them so the class indices are deterministic.
        class_name = sorted(
            entry for entry in os.listdir(self.path)
            if os.path.isdir(os.path.join(self.path, entry))
        )
        cls = []
        imgs = []
        class_label = {}
        for c, cl in enumerate(class_name):
            class_label[c] = cl
            for name in sorted(os.listdir(os.path.join(self.path, cl))):
                imgs.append(os.path.join(self.path, cl, name))
                cls.append(c)
        return imgs, cls, class_label

    def get_classlabel(self):
        """Return the dict mapping class index to folder name."""
        return self.class_label

    def __len__(self):
        """Return the number of image files found."""
        return len(self.img)
3、 訓練及測試
from get_data import cnndata
from torch.utils.data import DataLoader
import torch.nn as nn
from torch import optim
import torch
import time
from Lenet import Lenet
def train(batch=16, traindata=None, model=None, epochs=20, num_workers=2,
          save_path='lenet.pth'):
    """Train ``model`` on ``traindata`` with Adam and cross-entropy loss.

    Batches are moved to whatever device the model's parameters live on,
    so the same code runs on CPU and GPU. After every epoch the average
    per-sample loss and the training accuracy are printed; whenever the
    accuracy of an epoch exceeds 0.8 the weights are written to
    ``save_path``.

    Args:
        batch: Mini-batch size.
        traindata: Dataset yielding ``(image, label)`` pairs.
        model: The ``nn.Module`` to optimise (updated in place).
        epochs: Number of passes over ``traindata``.
        num_workers: Number of ``DataLoader`` worker processes.
        save_path: File the ``state_dict`` is saved to.

    Raises:
        ValueError: If ``traindata`` is empty.
    """
    total = len(traindata)
    if total == 0:
        raise ValueError("traindata is empty")
    dataloader = DataLoader(traindata, batch_size=int(batch), shuffle=True,
                            num_workers=num_workers)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device
    model.train()
    for epoch in range(int(epochs)):
        star = time.time()
        correct_num = 0
        epoch_loss = 0.0
        for img, cls in dataloader:
            img = img.to(device=device, dtype=torch.float32)
            cls = cls.to(device=device, dtype=torch.long)
            optimizer.zero_grad()
            out = model(img)
            loss = criterion(out, cls)
            loss.backward()
            optimizer.step()
            # ``loss`` is the mean over the batch; weight it by the batch
            # size so that dividing by ``total`` gives a per-sample average.
            epoch_loss += loss.item() * img.size(0)
            correct_num += torch.eq(cls, out.argmax(dim=1)).sum().item()
        acc = correct_num / total
        print("Epoch :%d , loss%.4f, acc:%.3f, time:%.3f"
              % (epoch, epoch_loss / total, acc, time.time() - star))
        if acc > 0.8:
            torch.save(model.state_dict(), save_path)
def test(batch=16, traindata=None, model=None, class_label=None,
         weights_path='lenet.pth', num_workers=2):
    """Load saved weights and print predictions next to the true labels.

    Args:
        batch: Mini-batch size.
        traindata: Dataset yielding ``(image, label)`` pairs to evaluate.
        model: Network whose architecture matches the saved weights.
        class_label: Index -> name mapping; accepted for interface
            compatibility but not used by this function.
        weights_path: File holding the ``state_dict`` to load.
        num_workers: Number of ``DataLoader`` worker processes.
    """
    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device
    # Evaluation gains nothing from shuffling; keep the dataset order so the
    # printed rows are reproducible.
    dataloader = DataLoader(traindata, batch_size=int(batch), shuffle=False,
                            num_workers=num_workers)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
    # machine. NOTE: torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    with torch.no_grad():
        for img, cls in dataloader:
            img = img.to(device=device, dtype=torch.float32)
            cls = cls.to(device=device, dtype=torch.long)
            out = model(img)
            pre = out.argmax(dim=1)
            print('pre is:', pre.detach().cpu().numpy(), 'label is:', cls.detach().cpu().numpy())
def weights_init(m):
    """Initialise a ``Conv2d`` layer; meant to be passed to ``Module.apply``.

    Convolution weights get Xavier-uniform values and the bias, when the
    layer has one, is set to zero. Modules of any other type are left
    untouched.

    Args:
        m: A module visited by ``Module.apply``.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
        # Conv2d(bias=False) has ``bias is None``; skip it instead of
        # raising AttributeError.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
if __name__ == '__main__':
    # Script entry point: build the dataset from ./data, create the network
    # and train it.
    batch = 2
    traindata = cnndata(path="./data")
    class_label = traindata.get_classlabel()
    print('index and labels:', class_label)
    # Size the classifier from the number of class folders actually found
    # rather than assuming there are exactly 10.
    model = Lenet(num_classes=len(class_label))
    # Use the GPU when one is present, otherwise stay on the CPU.
    if torch.cuda.is_available():
        model = model.cuda()
    model.apply(weights_init)
    train(batch=batch, traindata=traindata, model=model)
    # test(batch=batch, traindata=traindata, model=model, class_label=class_label)
訓練結果:
預測結果: