Getting Started with Keras (Part 2): Cats vs. Dogs with a VGG Network

Last time we used Keras to fit a simple linear equation; this time we implement the classic CNN architecture VGG-16. The figure below shows the structure of the VGG network.

[Figure: VGG-16 network architecture]

Here we use VGG-16: the input passes through several (convolution, pooling) stages, is flattened, and then goes through three fully connected layers into a one-dimensional vector, with softmax producing a score for each class. The full code is below; the data can be downloaded from the Kaggle cats-vs-dogs dataset (the code assumes training images are named like cat.0.jpg / dog.0.jpg, since the label is parsed from the file name).

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dense, Dropout, Activation, Flatten, BatchNormalization
from keras.optimizers import SGD, Adam
import os
import argparse
import random
import numpy as np
# imread/imresize were removed from scipy.misc in SciPy 1.3; this code needs an older
# SciPy (<=1.2) together with Pillow, or the two calls can be swapped for another image library.
from scipy.misc import imread, imresize
from keras.utils import to_categorical  # only needed for the (commented-out) softmax variant

parser = argparse.ArgumentParser()
parser.add_argument('--train_dir', default='./kaggle/train/')
parser.add_argument('--test_dir', default='./kaggle/test/')
parser.add_argument('--log_dir', default='./')
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--gpu', type=int, default=-1)
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)  # -1 hides all GPUs and runs on the CPU
type_list = ['cat', 'dog']  # label index 0 = cat, 1 = dog

def creat_vgg_16_net():
    model = Sequential()
    model.add(Conv2D(64, (3, 3), input_shape=(224, 224, 1), padding='same', activation='relu', name='conv1_block'))
    model.add(Conv2D(64, (3, 3), activation='relu', padding='same', name='conv2_block'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
    model.add(Conv2D(128, (3, 3), activation='relu', padding='same', name='conv3_block'))
    model.add(Conv2D(128, (3, 3), activation='relu', padding='same', name='conv4_block'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
    model.add(Conv2D(256, (3, 3), activation='relu', padding='same', name='conv5_block'))
    model.add(Conv2D(256, (3, 3), activation='relu', padding='same', name='conv6_block'))
    model.add(Conv2D(256, (1, 1), activation='relu', padding='same', name='conv7_block'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
    model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='conv8_block'))
    #model.add(Dropout(0.25))
    model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='conv9_block'))
    model.add(Conv2D(512, (1, 1), activation='relu', padding='same', name='conv10_block'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
    model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='conv11_block'))
    model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='conv12_block'))
    model.add(Conv2D(512, (1, 1), activation='relu', padding='same', name='con13_block'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
    model.add(Flatten())
    model.add(Dense(2048, activation='relu'))
    #model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(4096, activation='relu'))
    #model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(1, activation='sigmoid'))
    #model.add(Dense(2, activation='softmax'))
    return model
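
# Optional sanity check (not from the original post): building the model and printing a
# summary confirms that the 224x224x1 input is reduced to 7x7x512 by the five pooling
# stages before the Flatten layer.
# creat_vgg_16_net().summary()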


def prepare_data():
    file_list = os.listdir(args.train_dir)
    random.shuffle(file_list)
    # 80/20 split of the shuffled file list into training and validation generators
    train_num = int(len(file_list) * 0.8)
    train_file_list = file_list[0:train_num]
    validation_file_list = file_list[train_num:]
    return create_generate(train_file_list, args.batch_size, (224, 224)), \
           create_generate(validation_file_list, args.batch_size, (224, 224))


def create_generate(file_list, batch_size, input_size):
    # Infinite generator: yields (images, labels) batches of grayscale 224x224 images,
    # with labels parsed from file names such as cat.0.jpg / dog.0.jpg.
    while True:
        random.shuffle(file_list)
        image_data = np.zeros((batch_size, input_size[0], input_size[1], 1), dtype='float32')
        label_data = np.zeros((batch_size, 1), dtype='int32')
        for index, file_name in enumerate(file_list):
            image = imresize(imread(args.train_dir + file_name, mode='L'), input_size)
            label = file_name.split('.')[0]
            image_data[index % batch_size] = np.reshape(image / 255, (input_size[0], input_size[1], 1))
            label_data[index % batch_size] = type_list.index(label)
            if 0 == (index + 1) % batch_size:
                #label_data = to_categorical(label_data, 2)  # needed for the softmax variant
                yield image_data, label_data
                image_data = np.zeros((batch_size, input_size[0], input_size[1], 1), dtype='float32')
                label_data = np.zeros((batch_size, 1), dtype='int32')


def train(model):
    sgd = SGD(lr=0.001, decay=1e-8, momentum=0.9, nesterov=True)
    #sgd = Adam(lr=0.001)
    #model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
    model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])
    train_generate, validation_generate = prepare_data()
    model.fit_generator(generator=train_generate, steps_per_epoch=1200, epochs=50, verbose=1,
                        validation_data=validation_generate, validation_steps=50, max_queue_size=1,
                        shuffle=True)
    model.save_weights(args.log_dir + 'model.h5')


def predict(model):
    file_list = os.listdir(args.test_dir)
    cat_real_count = 0
    dog_real_count = 0
    cat_predict_count = 0
    dog_predict_count = 0
    for index, file_name in enumerate(file_list):
        label = file_name.split('.')[0]
        if 'cat' == label:
            cat_real_count = cat_real_count + 1
        else:
            dog_real_count = dog_real_count + 1
        image = imresize(imread(args.test_dir + file_name, mode='L'), (224, 224))
        score = model.predict(np.reshape(image / 255, (1, 224, 224, 1)))
        print(str(score))
        if score <= 0.5:  # sigmoid output below 0.5 -> class 0 (cat)
            cat_predict_count = cat_predict_count + 1
        else:
            dog_predict_count = dog_predict_count + 1
    # Note: this compares the overall number of predicted cats/dogs with the number of
    # real cats/dogs in the test folder, not per-image correctness.
    print(cat_real_count, cat_predict_count, cat_predict_count / cat_real_count,
          dog_real_count, dog_predict_count, dog_predict_count / dog_real_count)


if __name__ == '__main__':
    try:
        model = creat_vgg_16_net()
        #train(model)  # uncomment to train from scratch
        model.load_weights('E:/private/deeplearning/model.h5')  # path to previously saved weights
        predict(model)
    except Exception as err:
        print(err)
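
For reference, a minimal way to drive the script (the file name vgg_cat_dog.py is only an example): uncomment train(model) and comment out load_weights to train from scratch, then run something like

python vgg_cat_dog.py --train_dir ./kaggle/train/ --test_dir ./kaggle/test/ --gpu 0

With the default --gpu -1, CUDA_VISIBLE_DEVICES is set to -1, so TensorFlow falls back to the CPU.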

During training, the validation accuracy stayed around 75%, yet the accuracy measured at test time was higher; I am not sure why, so if anyone can explain this, please do.

The 1x1 convolutions introduced here are mainly meant to reduce computation and speed up the network; in some architectures, 1x1 convolutions are also used to reduce the channel dimension.
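As a rough, self-contained sketch of that bottleneck idea (my own toy example, not part of the network above): squeezing the channels with a 1x1 convolution before an expensive 3x3 convolution noticeably cuts the parameter count, and with it the computation.

from keras.models import Sequential
from keras.layers import Conv2D

# Plain 3x3 convolution on 256 input channels: 3*3*256*256 + 256 biases ≈ 590k parameters.
direct = Sequential([Conv2D(256, (3, 3), padding='same', input_shape=(56, 56, 256))])

# 1x1 bottleneck down to 64 channels, then the 3x3 convolution:
# 1*1*256*64 + 64 biases + 3*3*64*256 + 256 biases ≈ 164k parameters.
bottleneck = Sequential([
    Conv2D(64, (1, 1), padding='same', input_shape=(56, 56, 256)),
    Conv2D(256, (3, 3), padding='same'),
])

print(direct.count_params(), bottleneck.count_params())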

Update: (1) The accuracy that was previously stuck around 75% was probably a data-volume issue; training on the full 25,000 cat/dog images, the accuracy easily reaches over 90%.

(2) Code for classifying with softmax has been added (see the commented-out lines); for this binary classification task, the sigmoid output still seems to give slightly higher accuracy.
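For completeness, here is a minimal sketch of the two output heads discussed above; only the final layer, the loss, and the label encoding differ:

# Variant A (default in the code): one sigmoid unit, integer 0/1 labels, binary_crossentropy.
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])

# Variant B (the commented-out lines): two softmax units, one-hot labels via
# to_categorical(label_data, 2), and categorical_crossentropy.
# model.add(Dense(2, activation='softmax'))
# model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])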
