PyTorch implementation of LeNet-5

LeNet-5 architecture, with MNIST handwritten digit recognition implemented in PyTorch, TensorFlow, Keras (tf), and Paddle

Environment

Python 3.6, torch 1.0.1, torchvision 0.4.0, torchsummary 1.5.1
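The pinned versions can be installed with pip; exact wheel availability for these old releases depends on your platform and Python build, so treat this as a sketch rather than a verified setup:

pip install torch==1.0.1 torchvision==0.4.0 torchsummary==1.5.1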

Code

# -*- coding: utf-8 -*- 
# @Time : 2020/1/18 9:38 
# @Author : Zhao HL
# @File : lenet5_torch.py 
import torch, torchvision
from torchvision import transforms
from torch import optim, argmax
from torch.nn import Conv2d, Linear, MaxPool2d, ReLU, Softmax, Module, Sequential, CrossEntropyLoss
from torchsummary import summary
import sys, os
import numpy as np
from PIL import Image
from collections import OrderedDict

# region parameters
# region paths
Data_path = "./data/"
TestData_path = Data_path + 'pic/'
Model_path = 'model/'
Model_file_tf = "model/lenet5_tf.ckpt"
Model_file_keras = "model/lenet5_keras.h5"
Model_file_torch = "model/lenet5_torch.pth"
Model_file_paddle = "model/lenet5_paddle.model"
# endregion

# region image parameter
Img_size = 28
Img_chs = 1
Label_size = 1
Labels_classes = 10
# endregion

# region net parameter
Conv1_kernel_size = 5
Conv1_chs = 6
Conv2_kernel_size = 5
Conv2_chs = 16
Conv3_kernel_size = 5
Conv3_chs = 120
Flatten_size = 120
Fc1_size = 84
Fc2_size = Labels_classes
# endregion

# region hpyerparameter
Learning_rate = 1e-3
Batch_size = 64
Buffer_size = 256
Infer_size = 1
Epochs = 6
Train_num = 60000
Train_batch_num = Train_num // Batch_size
Val_num = 10000
Val_batch_num = Val_num // Batch_size
# endregion
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# endregion
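
# Layer-by-layer shape trace for a 28x28 input (this is why Flatten_size is 120):
# 1x28x28 -> conv1 (5x5, padding=2) -> 6x28x28 -> pool1 -> 6x14x14
#         -> conv2 (5x5) -> 16x10x10 -> pool2 -> 16x5x5
#         -> conv3 (5x5) -> 120x1x1 -> flatten -> 120 -> fc1 -> 84 -> fc2 -> 10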

class LeNet5(Module):
    def __init__(self):
        super(LeNet5, self).__init__()

        conv = OrderedDict([
            ('conv1', Conv2d(Img_chs, Conv1_chs, kernel_size=Conv1_kernel_size, padding=2)),
            ('relu1', ReLU()),
            ('pool1', MaxPool2d(kernel_size=2, stride=2)),
            ('conv2', Conv2d(Conv1_chs, Conv2_chs, kernel_size=Conv2_kernel_size, padding=0)),
            ('relu2', ReLU()),
            ('pool2', MaxPool2d(kernel_size=2, stride=2)),
            ('conv3', Conv2d(Conv2_chs, Conv3_chs, kernel_size=Conv3_kernel_size, padding=0)),
            ('relu3', ReLU()),
        ])

        fc = OrderedDict([
            ('fc1', Linear(Flatten_size, Fc1_size)),
            ('relu4', ReLU()),
            ('fc2', Linear(Fc1_size, Fc2_size)),
        ])
        self.conv = Sequential(conv)
        self.fc = Sequential(fc)


    def forward(self, input):
        conv_out = self.conv(input)
        flatten = conv_out.view(conv_out.size()[0], -1)
        output = self.fc(flatten)
        return output

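# Quick sanity check of the forward pass (a minimal sketch to try in a REPL):
#   net = LeNet5()
#   out = net(torch.zeros(1, Img_chs, Img_size, Img_size))
#   print(out.shape)  # expected: torch.Size([1, 10])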

def train(structShow=False):
    train_dataset = torchvision.datasets.MNIST(Data_path, train=True, download=True, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=Batch_size, shuffle=True)
    val_dataset = torchvision.datasets.MNIST(Data_path, train=False, transform=transforms.ToTensor())
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=Batch_size, shuffle=True)
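    # Note: on the first run torchvision downloads MNIST into Data_path
    # (torchvision 0.4.0 stores it under Data_path/MNIST/); later runs reuse the local copy.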

    model = LeNet5().to(device)
    if structShow:
        summary(model, (Img_chs, Img_size, Img_size), device=device.type)
    if os.path.exists(Model_file_torch):
        model.load_state_dict(torch.load(Model_file_torch))
        print('loaded model from', Model_file_torch)

    criterion = CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(),lr=Learning_rate)
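    # CrossEntropyLoss combines log-softmax and NLL loss internally, which is why
    # the network ends at fc2 with raw logits and no explicit Softmax layer.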

    best_loss = float("inf")
    best_loss_epoch = 0
    for epoch in range(Epochs):
        print('Epoch %d/%d:'%(epoch + 1, Epochs))
        train_sum_loss = 0
        train_sum_acc = 0
        val_sum_loss = 0
        val_sum_acc = 0

        model.train()
        with torch.set_grad_enabled(True):
            for batch_num, (images, labels) in enumerate(train_loader):
                images, labels = images.to(device), labels.to(device)

                optimizer.zero_grad()

                output = model(images)
                loss = criterion(output, labels)
                train_sum_loss += loss.item()

                loss.backward()
                optimizer.step()

                _, predicted = torch.max(output.data, 1)
                correct = (predicted == labels).sum().float()
                acc = correct / labels.size(0)
                train_sum_acc += acc

                process_show(batch_num + 1, len(train_loader), acc, loss, prefix='train:')


        model.eval()
        with torch.set_grad_enabled(False):
            for batch_num, (images, labels) in enumerate(val_loader):
                images, labels = images.to(device), labels.to(device)

                output = model(images)
                loss = criterion(output, labels)
                val_sum_loss += loss.item()

                _, predicted = torch.max(output.data, 1)
                correct = (predicted == labels).sum().float()
                acc = correct / labels.size(0)
                val_sum_acc += acc

                process_show(batch_num + 1, len(val_loader), acc, loss, prefix='val:')

        train_sum_loss /= len(train_loader)
        train_sum_acc /= len(train_loader)
        val_sum_loss /= len(val_loader)
        val_sum_acc /= len(val_loader)
        print('average summary:\ntrain acc %.4f, loss %.4f ; val acc %.4f, loss %.4f'
              % (train_sum_acc, train_sum_loss, val_sum_acc, val_sum_loss))
        if val_sum_loss < best_loss:
            print('val_loss improved from %.4f to %.4f, model saved to %s ! \n' % (best_loss, val_sum_loss, Model_file_torch))
            best_loss = val_sum_loss
            best_loss_epoch = epoch+1
            torch.save(model.state_dict(), Model_file_torch)
        else:
            print('val_loss did not improve from %.4f \n' % (best_loss))
    print('best loss %.4f at epoch %d \n'%(best_loss,best_loss_epoch))

def inference(infer_path=TestData_path, model_path=Model_file_torch):
    '''
    Run inference on a directory of images.
    :param infer_path: directory containing the images to classify
    :param model_path: path to the trained model weights
    :return:
    '''
    model = LeNet5().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    infer_transform = transforms.ToTensor()

    with torch.no_grad():
        for image_name in os.listdir(infer_path):
            image = load_image(infer_path+image_name)
            img = infer_transform(image)
            img = img.unsqueeze(0).to(device)
            result = model(img)
            _, pre = torch.max(result.data, 1)
            pre = pre.cpu().numpy()
            print("{} predict result {}".format(image_name, pre))


def process_show(num, nums, acc, loss, prefix='', suffix=''):
    rate = num / nums
    ratenum = int(round(rate, 2) * 100)
    bar = '\r%s batch %3d/%d: accuracy %.4f, loss %.4f [%s%s]%.1f%% %s; ' % (
        prefix, num, nums, acc, loss, '#' * (ratenum // 2), '_' * (50 - ratenum // 2), ratenum, suffix)
    sys.stdout.write(bar)
    sys.stdout.flush()
    if num >= nums:
        print()

def load_image(file):
    img = Image.open(file).convert('L')
    img = img.resize((Img_size, Img_size), Image.ANTIALIAS)
    img = np.array(img).reshape(Img_size, Img_size).astype(np.float32)
    img /= 255.0
    return img
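# Note (an assumption about the files in TestData_path): MNIST digits are white
# strokes on a black background. If the test images are dark digits on a light
# background, invert them first (e.g. img = 1.0 - img) before feeding the model.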

if __name__ == '__main__':
    train(structShow=True)
    inference()
