PyTorch 實現 GoogLeNet-InceptionV3

環境

Python 3.6, torch 1.0.1, torchvision 0.4.0, torchsummary 1.5.1

 

代碼

# -*- coding: utf-8 -*- 
# @Time : 2020/2/21 13:53 
# @Author : Zhao HL
# @File : InceptionV3-torch.py
import torch, torchvision
from torchvision import transforms
from torch import optim, argmax
from torch.nn import Conv2d, Linear, MaxPool2d, MaxPool2d,BatchNorm2d, ReLU, Softmax, Dropout, Module, Sequential, CrossEntropyLoss
from torchsummary import summary
import sys, os
import numpy as np
from PIL import Image
import pandas as pd
from collections import OrderedDict
from my_utils import process_show, draw_loss_acc, dataInfo_show, dataset_divide

# region parameters
# Module-level configuration shared by the dataset, the model classes and train().
# region paths
# Directory holding the image files; MyDataset joins it with each listed file name.
Data_path = "./data/"
# CSV read by train(): column 0 is the index, plus 'filename' and 'split'
# ('train' / 'val' / 'test') columns.
Data_csv_path = "./data/split.txt"
Model_path = 'model/'
# NOTE(review): these file names say "InceptionV1" although this script builds
# InceptionV3 -- looks like a leftover from an earlier script; confirm before
# renaming, since existing checkpoints would then no longer be found.
# Only Model_file_torch is used in this file (train() saves the best state_dict there).
Model_file_tf = "model/InceptionV1_tf.ckpt"
Model_file_keras = "model/InceptionV1_keras.h5"
Model_file_torch = "model/InceptionV1_torch.pth"
Model_file_paddle = "model/InceptionV1_paddle.model"
# endregion

# region image parameter
# train() resizes every image to Img_size x Img_size; Img_chs is the input
# channel count of the first convolution.
Img_size = 299
Img_chs = 3
# Not referenced elsewhere in this file.
Label_size = 1
# Class names; a sample's label is the index of its class name in this list
# (MyDataset derives the class name from the file name).
Label_class = ['agricultural',
               'airplane',
               'baseballdiamond',
               'beach',
               'buildings',
               'chaparral',
               'denseresidential',
               'forest',
               'freeway',
               'golfcourse',
               'harbor',
               'intersection',
               'mediumresidential',
               'mobilehomepark',
               'overpass',
               'parkinglot',
               'river',
               'runway',
               'sparseresidential',
               'storagetanks',
               'tenniscourt']
# Number of classes = output width of both classifier heads.
Labels_nums = len(Label_class)
# endregion

# region net parameter
# Output channels of the six stem convolutions in InceptionV3.conv.
Conv1_chs = 32
Conv2_chs = 32
Conv3_chs = 64
Conv4_chs = 80
Conv5_chs = 192
Conv6_chs = 288
# Inception block sizes, unpacked by the block classes as
# (input_chs, con1_chs, con31_chs, con3_chs, con51_chs, con5_chs, pool_chs).
# 3c and 5e are built with downsample=True, which has no 1x1 branch, so their
# con1_chs entry (0) is unused.
Icp3a_size = (288, 64, 64, 96, 48, 64, 64)
Icp3b_size = (288, 64, 64, 96, 48, 64, 64)
Icp3c_size = (288, 0, 192, 384, 64, 96, 288)
Icp5a_size = (768, 192, 160, 192, 160, 192, 192)
Icp5b_size = (768, 192, 160, 192, 160, 192, 192)
Icp5c_size = (768, 192, 160, 192, 160, 192, 192)
Icp5d_size = (768, 192, 160, 192, 160, 192, 192)
Icp5e_size = (768, 0, 192, 320, 192, 192, 768)
Icp2a_size = (1280, 320, 384, 384, 448, 384, 192)
Icp2b_size = (2048, 320, 384, 384, 448, 384, 192)
# endregion

# region hyperparameter
# Initial Adam learning rate; the StepLR in train() multiplies it by 0.94 every 2 epochs.
# NOTE(review): 0.045 is unusually large for Adam (it looks like an RMSProp/SGD
# value) -- confirm it is intended.
Learning_rate = 0.045
Batch_size = 8
# Buffer_size and Infer_size are not referenced elsewhere in this file.
Buffer_size = 256
Infer_size = 1
Epochs = 20
# Per-split sample counts: 1470/210/420 of 2100 images, i.e. the 0.7 / 0.8 cut
# points used by dataset_divide. The *_batch_num values are not referenced
# elsewhere in this file.
Train_num = 1470
Train_batch_num = Train_num // Batch_size
Val_num = 210
Val_batch_num = Val_num // Batch_size
Test_num = 420
Test_batch_num = Test_num // Batch_size
# endregion
# Run on CUDA when available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# endregion

class MyDataset(torch.utils.data.Dataset):
    """Image-classification dataset whose class is encoded in each file name.

    Args:
        root_path: directory containing the images.
        files_list: file names (relative to ``root_path``) to use. When ``None``,
            every entry of ``root_path`` is used, in sorted order.
        transform: optional callable applied to the loaded PIL image.

    ``__getitem__`` returns ``(image, label)``. The label is the index in
    ``Label_class`` of the file's base name with its extension and trailing
    digits removed (``'agricultural00.tif'`` -> ``'agricultural'``).
    """

    def __init__(self, root_path, files_list=None, transform=None):
        self.root_path = root_path
        self.transform = transform
        # Test against None rather than truthiness: an explicitly empty list must
        # stay empty instead of silently turning into "every file in root_path".
        if files_list is None:
            files_list = sorted(os.listdir(root_path))
        self.files_list = list(files_list)
        # Use the resolved list: the argument itself is None in the listdir case.
        self.size = len(self.files_list)

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        file_name = self.files_list[index]
        # convert() reads the pixel data, so the file handle can be closed right
        # away; RGB matches the 3-channel normalisation / first conv of this script.
        with Image.open(os.path.join(self.root_path, file_name)) as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = Label_class.index(self._label_name(file_name))
        return img, label

    @staticmethod
    def _label_name(file_name):
        """Return the class name encoded in ``file_name`` (extension and trailing digits stripped)."""
        stem = os.path.splitext(os.path.basename(file_name))[0]
        return stem.rstrip('0123456789')


class InceptionV3_ModelA(Module):
    """Inception block made of 1x1 and 3x3 convolutions, optionally reducing the grid.

    Args:
        model_size: 7-tuple ``(input_chs, con1_chs, con31_chs, con3_chs,
            con51_chs, con5_chs, pool_chs)``.
        downsample: when true, the 1x1 branch is omitted and the last layer of
            every remaining branch uses stride 2 with no padding. Otherwise every
            layer keeps the spatial size (3x3 kernels use stride 1, padding 1).

    Every convolution is followed by ReLU and then BatchNorm2d. ``forward``
    concatenates along dim 1: the pool branch (max-pool + 1x1), the 1x1 branch
    (regular mode only), the 1x1->3x3 branch and the 1x1->3x3->3x3 branch.
    """

    @staticmethod
    def _conv_relu_bn(in_chs, out_chs, **conv_kwargs):
        """Return ``[Conv2d, ReLU, BatchNorm2d]`` ready to be splatted into a Sequential."""
        return [Conv2d(in_chs, out_chs, **conv_kwargs), ReLU(), BatchNorm2d(out_chs)]

    def __init__(self, model_size,downsample=False):
        super(InceptionV3_ModelA, self).__init__()
        (input_chs, con1_chs, con31_chs, con3_chs,
         con51_chs, con5_chs, pool_chs) = model_size
        cbr = self._conv_relu_bn
        # Settings for the final layer of each branch: stride-2 reduction or size-preserving.
        tail = {'stride': 2, 'padding': 0} if downsample else {'stride': 1, 'padding': 1}

        # Conv2d draws its initial weights from torch's global RNG when it is
        # constructed, so the creation order conv1, conv3, conv5, pool1 is fixed.
        if downsample == False:
            self.conv1 = Sequential(*cbr(input_chs, con1_chs, kernel_size=1))
        self.conv3 = Sequential(
            *cbr(input_chs, con31_chs, kernel_size=1),
            *cbr(con31_chs, con3_chs, kernel_size=3, **tail)
        )
        self.conv5 = Sequential(
            *cbr(input_chs, con51_chs, kernel_size=1),
            *cbr(con51_chs, con5_chs, kernel_size=3, padding=1),
            *cbr(con5_chs, con5_chs, kernel_size=3, **tail)
        )
        self.pool1 = Sequential(
            MaxPool2d(kernel_size=3, **tail),
            *cbr(input_chs, pool_chs, kernel_size=1)
        )
        self.downsample = downsample

    def forward(self, input):
        pooled = self.pool1(input)
        branch3 = self.conv3(input)
        branch5 = self.conv5(input)
        branches = [pooled, branch3, branch5]
        if not self.downsample:
            # The 1x1 branch sits right after the pool branch.
            branches.insert(1, self.conv1(input))
        return torch.cat(branches, dim=1)

class InceptionV3_ModelB(Module):
    """Inception block with factorised 7x7 convolutions (1x7 followed by 7x1).

    Args:
        model_size: 7-tuple ``(input_chs, con1_chs, con31_chs, con3_chs,
            con51_chs, con5_chs, pool_chs)``.
        downsample: selects the grid-reduction variant.

    Regular mode (spatial size preserved) has four branches: max-pool + 1x1,
    1x1, 1x1->1x7->7x1, and 1x1->1x7->7x1->1x7->7x1. With ``downsample`` there
    are three: max-pool (stride 2) + 1x1, 1x1->3x3 (stride 2), and
    1x1->1x7->7x1->3x3 (stride 2). In that mode ``con1_chs`` is unused.
    Every convolution is followed by ReLU and then BatchNorm2d, and the
    branch outputs are concatenated along dim 1.
    """

    @staticmethod
    def _conv_relu_bn(in_chs, out_chs, **conv_kwargs):
        """Return ``[Conv2d, ReLU, BatchNorm2d]`` ready to be splatted into a Sequential."""
        return [Conv2d(in_chs, out_chs, **conv_kwargs), ReLU(), BatchNorm2d(out_chs)]

    def __init__(self, model_size,downsample=False):
        super(InceptionV3_ModelB, self).__init__()
        (input_chs, con1_chs, con31_chs, con3_chs,
         con51_chs, con5_chs, pool_chs) = model_size
        cbr = self._conv_relu_bn
        # Stride/padding of the reducing layers (and of the pool).
        tail = {'stride': 2, 'padding': 0} if downsample else {'stride': 1, 'padding': 1}
        # Size-preserving halves of a factorised 7x7 convolution.
        row7 = {'kernel_size': (1, 7), 'padding': (0, 3)}
        col7 = {'kernel_size': (7, 1), 'padding': (3, 0)}

        # Conv2d draws its initial weights from torch's global RNG when it is
        # constructed, so pool1 is built first and each branch in a fixed order.
        self.pool1 = Sequential(
            MaxPool2d(kernel_size=3, **tail),
            *cbr(input_chs, pool_chs, kernel_size=1)
        )
        if not downsample:
            self.conv1 = Sequential(*cbr(input_chs, con1_chs, kernel_size=1))
            self.conv3 = Sequential(
                *cbr(input_chs, con31_chs, kernel_size=1),
                *cbr(con31_chs, con31_chs, **row7),
                *cbr(con31_chs, con3_chs, **col7)
            )
            self.conv5 = Sequential(
                *cbr(input_chs, con51_chs, kernel_size=1),
                *cbr(con51_chs, con51_chs, **row7),
                *cbr(con51_chs, con51_chs, **col7),
                *cbr(con51_chs, con51_chs, **row7),
                *cbr(con51_chs, con5_chs, **col7)
            )
        else:
            self.conv3 = Sequential(
                *cbr(input_chs, con31_chs, kernel_size=1),
                *cbr(con31_chs, con3_chs, kernel_size=3, **tail)
            )
            self.conv5 = Sequential(
                *cbr(input_chs, con51_chs, kernel_size=1),
                *cbr(con51_chs, con51_chs, **row7),
                *cbr(con51_chs, con51_chs, **col7),
                *cbr(con51_chs, con5_chs, kernel_size=3, **tail)
            )
        self.downsample = downsample

    def forward(self, input):
        pooled = self.pool1(input)
        branch3 = self.conv3(input)
        branch5 = self.conv5(input)
        branches = [pooled, branch3, branch5]
        if not self.downsample:
            # The 1x1 branch sits right after the pool branch.
            branches.insert(1, self.conv1(input))
        return torch.cat(branches, dim=1)

class InceptionV3_ModelC(Module):
    """Inception block whose 3x3 branches split into parallel 3x1 and 1x3 convolutions.

    Args:
        model_size: 7-tuple ``(input_chs, con1_chs, con31_chs, con3_chs,
            con51_chs, con5_chs, pool_chs)``.

    Spatial size is preserved. Every convolution is followed by ReLU and then
    BatchNorm2d. The output concatenates along dim 1:
    pool (max-pool + 1x1, ``pool_chs``), 1x1 (``con1_chs``),
    1x1 -> {3x1, 1x3} (``2 * con3_chs``) and
    1x1 -> 3x3 -> {3x1, 1x3} (``2 * con5_chs``).
    """

    @staticmethod
    def _conv_relu_bn(in_chs, out_chs, **conv_kwargs):
        """Return ``[Conv2d, ReLU, BatchNorm2d]`` ready to be splatted into a Sequential."""
        return [Conv2d(in_chs, out_chs, **conv_kwargs), ReLU(), BatchNorm2d(out_chs)]

    def __init__(self, model_size):
        super(InceptionV3_ModelC, self).__init__()
        (input_chs, con1_chs, con31_chs, con3_chs,
         con51_chs, con5_chs, pool_chs) = model_size
        cbr = self._conv_relu_bn
        tall = {'kernel_size': (3, 1), 'padding': (1, 0)}
        wide = {'kernel_size': (1, 3), 'padding': (0, 1)}

        # Conv2d draws its initial weights from torch's global RNG when it is
        # constructed, so the submodules are created in a fixed order.
        self.pool1 = Sequential(
            MaxPool2d(kernel_size=3, stride=1, padding=1),
            *cbr(input_chs, pool_chs, kernel_size=1)
        )
        self.conv1 = Sequential(*cbr(input_chs, con1_chs, kernel_size=1))
        self.conv30 = Sequential(*cbr(input_chs, con31_chs, kernel_size=1))
        self.conv31 = Sequential(*cbr(con31_chs, con3_chs, **tall))
        self.conv32 = Sequential(*cbr(con31_chs, con3_chs, **wide))
        self.conv50 = Sequential(
            *cbr(input_chs, con51_chs, kernel_size=1),
            *cbr(con51_chs, con5_chs, kernel_size=3, padding=1)
        )
        self.conv51 = Sequential(*cbr(con5_chs, con5_chs, **tall))
        self.conv52 = Sequential(*cbr(con5_chs, con5_chs, **wide))

    def forward(self, input):
        pieces = [self.pool1(input), self.conv1(input)]
        shared3 = self.conv30(input)
        pieces += [self.conv31(shared3), self.conv32(shared3)]
        shared5 = self.conv50(input)
        pieces += [self.conv51(shared5), self.conv52(shared5)]
        # A single flat concatenation equals concatenating each split pair first.
        return torch.cat(pieces, dim=1)

class InceptionV3_Out(Module):
    """Auxiliary classifier head applied to a 768-channel feature map.

    ``forward`` returns raw class scores (logits) of shape ``(batch, Labels_nums)``.
    The flatten + ``Linear(768, ...)`` requires the conv stack to reduce the map
    to 1x1; a 17x17 input gives (17-5)//3+1 = 5 after the pool and 1 after the
    5x5 convolution. Apply ``torch.softmax(logits, dim=1)`` to get probabilities.
    """
    def __init__(self):
        super(InceptionV3_Out, self).__init__()
        # NOTE(review): the auxiliary head in Inception-v3 normally uses average
        # pooling; max pooling is kept unchanged here.
        self.conv = Sequential(
            MaxPool2d(kernel_size=5, stride=3),
            Conv2d(768, 128, kernel_size=1),
            ReLU(),
            BatchNorm2d(128),
            Conv2d(128, 768, kernel_size=5),
            ReLU(),
            BatchNorm2d(768),
        )
        # Logits only. CrossEntropyLoss (used by train()) applies log-softmax
        # itself, so the former Softmax layer normalised the scores twice and
        # flattened the gradients; it also relied on the deprecated implicit
        # dim. The Linear stays at index 0, so state_dict keys (out.0.*) are unchanged.
        self.out = Sequential(
            Linear(768, Labels_nums),
        )

    def forward(self, input):
        x = self.conv(input)
        x = x.view(x.size(0), -1)
        return self.out(x)



class InceptionV3(Module):
    """Inception-v3 style classifier.

    The layer sizes assume ``Img_size`` = 299 inputs: the stem yields 35x35,
    inception3c reduces to 17x17 and inception5e to 8x8, which the final
    8x8 pool turns into a 2048-vector.

    ``forward`` returns raw class scores (logits) of shape ``(batch, Labels_nums)``.
    In training mode it returns ``(logits, aux_logits)``, where the second value
    comes from the auxiliary head after ``inception5d``. Apply
    ``torch.softmax(..., dim=1)`` when probabilities are needed.
    """
    def __init__(self):
        super(InceptionV3, self).__init__()

        # Stem: each convolution is followed by ReLU and then BatchNorm2d.
        self.conv = Sequential(
            Conv2d(Img_chs, Conv1_chs, kernel_size=3, stride=2, ),
            ReLU(),
            BatchNorm2d(Conv1_chs),
            Conv2d(Conv1_chs, Conv2_chs, kernel_size=3, stride=1, ),
            ReLU(),
            BatchNorm2d(Conv2_chs),
            Conv2d(Conv2_chs, Conv3_chs, kernel_size=3, stride=1, padding=1),
            ReLU(),
            BatchNorm2d(Conv3_chs),
            MaxPool2d(kernel_size=3, stride=2),
            Conv2d(Conv3_chs, Conv4_chs, kernel_size=3, stride=1),
            ReLU(),
            BatchNorm2d(Conv4_chs),
            Conv2d(Conv4_chs, Conv5_chs, kernel_size=3, stride=2),
            ReLU(),
            BatchNorm2d(Conv5_chs),
            Conv2d(Conv5_chs, Conv6_chs, kernel_size=3, stride=1, padding=1),
            ReLU(),
            BatchNorm2d(Conv6_chs),
        )
        self.inception3a = InceptionV3_ModelA(Icp3a_size)
        self.inception3b = InceptionV3_ModelA(Icp3b_size)
        self.inception3c = InceptionV3_ModelA(Icp3c_size,downsample=True)

        self.inception5a = InceptionV3_ModelB(Icp5a_size)
        self.inception5b = InceptionV3_ModelB(Icp5b_size)
        self.inception5c = InceptionV3_ModelB(Icp5c_size)
        self.inception5d = InceptionV3_ModelB(Icp5d_size)
        # Built between 5d and 5e so the weight-initialisation order is unchanged.
        self.auxout = InceptionV3_Out()
        self.inception5e = InceptionV3_ModelB(Icp5e_size,downsample=True)

        self.inception2a = InceptionV3_ModelC(Icp2a_size)
        self.inception2b = InceptionV3_ModelC(Icp2b_size)

        # NOTE(review): Inception-v3 normally uses average pooling here; max
        # pooling is kept unchanged.
        self.pool = Sequential(
            MaxPool2d(kernel_size=8, stride=1),
        )
        # Logits only: CrossEntropyLoss (used by train()) applies log-softmax
        # itself, so the former trailing Softmax normalised the scores twice and
        # used the deprecated implicit dim. Removing the last layer keeps the
        # state_dict keys (out.1.*) unchanged.
        self.out = Sequential(
            Dropout(p=0.4),
            Linear(2048, Labels_nums),
        )

    def forward(self, input):
        x = self.conv(input)
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.inception3c(x)

        x = self.inception5a(x)
        x = self.inception5b(x)
        x = self.inception5c(x)
        x = self.inception5d(x)
        # The auxiliary head only contributes to the training loss.
        if self.training:
            auxout = self.auxout(x)
        x = self.inception5e(x)

        x = self.inception2a(x)
        x = self.inception2b(x)

        x = self.pool(x)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        if self.training:
            return output, auxout
        return output


def _build_loaders():
    """Return ``(train_loader, val_loader)`` built from the split CSV at ``Data_csv_path``."""
    transform = transforms.Compose([
        transforms.Resize((Img_size, Img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    df = pd.read_csv(Data_csv_path, header=0, index_col=0)

    def make_loader(split, shuffle):
        files = df[df['split'] == split]['filename'].tolist()
        if not files:
            # Fail clearly instead of a ZeroDivisionError when averaging later.
            raise ValueError('no %r images listed in %s' % (split, Data_csv_path))
        dataset = MyDataset(Data_path, files_list=files, transform=transform)
        return torch.utils.data.DataLoader(dataset, batch_size=Batch_size, shuffle=shuffle)

    # Validation order does not change the metrics, so it is not shuffled.
    return make_loader('train', True), make_loader('val', False)


def _batch_accuracy(output, labels):
    """Return the fraction of rows whose arg-max equals ``labels``, as a Python float."""
    return (output.argmax(dim=1) == labels).float().mean().item()


def _train_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``; return (mean loss, mean accuracy)."""
    model.train()
    sum_loss = sum_acc = 0.0
    for batch_num, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        output, aux_output = model(images)
        # Main head weighted 0.7, auxiliary head 0.3.
        total_loss = 0.7 * criterion(output, labels) + 0.3 * criterion(aux_output, labels)
        total_loss.backward()
        optimizer.step()

        # .item() keeps the running sums as plain floats instead of device tensors.
        loss_value = total_loss.item()
        acc = _batch_accuracy(output, labels)
        sum_loss += loss_value
        sum_acc += acc
        process_show(batch_num + 1, len(loader), acc, loss_value, prefix='train:')
    return sum_loss / len(loader), sum_acc / len(loader)


def _val_epoch(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients; return (mean loss, mean accuracy)."""
    model.eval()
    sum_loss = sum_acc = 0.0
    with torch.no_grad():
        for batch_num, (images, labels) in enumerate(loader):
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            loss_value = criterion(output, labels).item()
            acc = _batch_accuracy(output, labels)
            sum_loss += loss_value
            sum_acc += acc
            process_show(batch_num + 1, len(loader), acc, loss_value, prefix='val:')
    return sum_loss / len(loader), sum_acc / len(loader)


def train(structShow=False):
    """Train InceptionV3 on the train/val split listed in ``Data_csv_path``.

    Args:
        structShow: print the torchsummary layer table before training.

    The weighted main + auxiliary loss is optimised with Adam. The state_dict
    is saved to ``Model_file_torch`` whenever the validation loss improves, and
    the per-epoch curves are plotted with ``draw_loss_acc`` at the end.
    """
    train_loader, val_loader = _build_loaders()

    model = InceptionV3().to(device)
    model.train()
    if structShow:
        # summary() prints the table itself (it returns None) and defaults to
        # device="cuda", so pass the device actually in use.
        summary(model, (Img_chs, Img_size, Img_size), device=device.type)
    # if os.path.exists(Model_file_torch):
    #     model.load_state_dict(torch.load(Model_file_torch))
    #     print('get model from',Model_file_torch)

    criterion = CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=Learning_rate)
    # Multiply the learning rate by 0.94 every 2 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.94)

    train_loss = np.ones(Epochs)
    train_acc = np.ones(Epochs)
    val_loss = np.ones(Epochs)
    val_acc = np.ones(Epochs)
    best_loss = float("inf")
    best_loss_epoch = 0
    save_dir = os.path.dirname(Model_file_torch)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for epoch in range(Epochs):
        print('Epoch %d/%d:' % (epoch + 1, Epochs))
        train_loss[epoch], train_acc[epoch] = _train_epoch(model, train_loader, criterion, optimizer)
        # Since PyTorch 1.1 the scheduler must step after the epoch's optimizer
        # steps; stepping first skipped the initial learning-rate value.
        scheduler.step()
        val_loss[epoch], val_acc[epoch] = _val_epoch(model, val_loader, criterion)

        print('average summary:\ntrain acc %.4f, loss %.4f ; val acc %.4f, loss %.4f'
              % (train_acc[epoch], train_loss[epoch], val_acc[epoch], val_loss[epoch]))
        if val_loss[epoch] < best_loss:
            print('val_loss improve from %.4f to %.4f, model save to %s ! \n' % (
                best_loss, val_loss[epoch], Model_file_torch))
            best_loss = val_loss[epoch]
            best_loss_epoch = epoch + 1
            torch.save(model.state_dict(), Model_file_torch)
        else:
            print('val_loss do not improve from %.4f \n' % (best_loss))
    print('best loss %.4f at epoch %d \n' % (best_loss, best_loss_epoch))
    draw_loss_acc(train_loss, train_acc, 'train')
    draw_loss_acc(val_loss, val_acc, 'val')


if __name__ == '__main__':
    # Generate the train/val/test split once beforehand with
    # dataset_divide(<path to split.txt>), then start training.
    train(structShow=True)

my_utils.py

# -*- coding: utf-8 -*- 
# @Time : 2020/1/21 11:39 
# @Author : Zhao HL
# @File : my_utils.py
import sys,os,random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
def process_show(num, nums, train_acc, train_loss, prefix='', suffix=''):
    """Redraw a one-line text progress bar for batch ``num`` out of ``nums``.

    Args:
        num: 1-based index of the batch just finished.
        nums: total number of batches (0 is treated as complete).
        train_acc: accuracy shown with 4 decimals (anything ``%f`` accepts).
        train_loss: loss shown with 4 decimals.
        prefix: text placed before the bar.
        suffix: text placed after the bar.

    The line starts with ``\\r`` so it overwrites itself. A newline is printed
    once ``num >= nums``.
    """
    rate = num / nums if nums else 1.0
    # round(rate * 100) rather than round(rate, 2) * 100: the latter carries
    # float error (0.29 * 100 == 28.999999999999996) that int() truncated to 28.
    ratenum = int(round(rate * 100))
    filled = ratenum // 2
    bar = '\r%s batch %3d/%d:train accuracy %.4f, train loss %00.4f [%s%s]%.1f%% %s; ' % (
        prefix, num, nums, train_acc, train_loss, '#' * filled, '_' * (50 - filled), ratenum, suffix)
    sys.stdout.write(bar)
    sys.stdout.flush()
    if num >= nums:
        print()

def dataInfo_show(data_path,csv_pth,cls_dic_path,shapesShow=True,classesShow=True):
    """Print class statistics and/or the distribution of image array shapes.

    Args:
        data_path: directory whose every entry is opened as an image when ``shapesShow``.
        csv_pth: CSV with a ``label`` column, read when ``classesShow``.
        cls_dic_path: class-description CSV passed to ``get_cls_dic`` to map labels to names.
        shapesShow: print how many images have each ``np.array(img).shape``.
        classesShow: print the label -> name mapping of the labels present and per-class counts.
    """
    cls_dict = get_cls_dic(cls_dic_path)
    if classesShow:
        print('\n' + '*' * 50)
        df = pd.read_csv(csv_pth)
        labels = df['label'].unique()
        label_cls = {label: cls_dict[label] for label in labels}
        print(label_cls)
        cls_count = {cls_dict[k]: v for k, v in df['label'].value_counts().items()}
        for name, count in cls_count.items():
            print(name, count)

    if shapesShow:
        print('\n' + '*' * 50)
        shapes = []
        for filename in os.listdir(data_path):
            # The context manager closes each file; np.array() has already read
            # the pixels by then. Previously every handle was left to the GC.
            with Image.open(os.path.join(data_path, filename)) as img:
                shapes.append(np.array(img).shape)
        print(pd.Series(shapes).value_counts())

def get_cls_dic(cls_dic_path):
    """Build a ``{class_id: label}`` dict from a class-description CSV.

    The CSV must have ``info`` and ``other`` columns. pandas splits each row at
    commas, so ``info`` holds only the text before the first comma; its first
    9 characters are the class id and everything from index 10 on is the label.
    The ``other`` column (the text after the comma) is discarded.
    """
    table = pd.read_csv(cls_dic_path)
    info = table['info']
    table['cls'] = info.apply(lambda text: text[:9]).tolist()
    table['label'] = info.apply(lambda text: text[10:]).tolist()
    by_cls = table.drop(columns=['info', 'other']).set_index('cls')

    # One list of remaining column values per class id; keep the first value.
    per_cls = by_cls.T.to_dict('list')
    mapping = {}
    for cls_id, values in per_cls.items():
        mapping[cls_id] = values[0]
    return mapping

def dataset_divide(csv_pth, train_ratio=0.7, train_val_ratio=0.8, seed=None):
    """Randomly assign every row of a file-list CSV to 'train', 'val' or 'test', in place.

    Args:
        csv_pth: CSV read with ``index_col=0`` that has a ``filename`` column.
            It is overwritten with an added (or reset) ``split`` column.
        train_ratio: fraction of files placed in 'train'.
        train_val_ratio: cumulative cut point; files between ``train_ratio`` and
            this fraction go to 'val', and the rest go to 'test'.
        seed: optional seed for a reproducible split. ``None`` uses the global
            ``random`` state, as before.
    """
    cls_df = pd.read_csv(csv_pth, header=0, index_col=0)
    if 'split' in cls_df.columns:
        # Re-splitting a file this function already wrote: reset the column,
        # because insert() raises ValueError for an existing column name.
        cls_df['split'] = None
    else:
        cls_df.insert(1, 'split', None)
    filenames = list(cls_df['filename'])
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(filenames)
    train_num = int(len(filenames) * train_ratio)
    train_val_num = int(len(filenames) * train_val_ratio)
    train_names = filenames[:train_num]
    val_names = filenames[train_num:train_val_num]
    test_names = filenames[train_val_num:]
    cls_df.loc[cls_df['filename'].isin(train_names), 'split'] = 'train'
    cls_df.loc[cls_df['filename'].isin(val_names), 'split'] = 'val'
    cls_df.loc[cls_df['filename'].isin(test_names), 'split'] = 'test'
    cls_df.to_csv(csv_pth)

def draw_loss_acc(loss,acc,type='',save_path=None):
    """Plot per-epoch accuracy (top) and loss (bottom) on the current figure and show it.

    Args:
        loss: per-epoch loss values.
        acc: per-epoch accuracy values (same length as ``loss``).
        type: name prefix used in the titles and in the saved file name.
        save_path: if given, the figure is also written to
            ``<save_path>/<type>_acc_loss.png``.
    """
    assert len(acc) == len(loss)
    epochs = list(range(len(acc)))
    plt.subplot(2, 1, 1)
    plt.plot(epochs, acc, 'o-')
    plt.title(type+'  accuracy vs. epoches')
    plt.ylabel('accuracy')
    plt.subplot(2, 1, 2)
    plt.plot(epochs, loss, '.-')
    plt.xlabel(type+'  loss vs. epoches')
    plt.ylabel('loss')
    # Save before show(): show() blocks, and the figure is typically closed
    # once its window is dismissed, so saving afterwards wrote a blank image.
    if save_path:
        plt.savefig(os.path.join(save_path,type+"_acc_loss.png"))
    plt.show()


if __name__ == '__main__':
    # Helper module only; running it directly does nothing.
    pass

 

發表評論
所有評論
還沒有人評論，想成為第一個評論的人嗎？請在上方評論欄輸入並點擊發布。
相關文章