在MNIST數據集上準確率99.26%
1、項目簡介
MNIST項目基本上是深度學習初學者的入門項目,本文主要介紹使用keras框架通過構建CNN網絡實現在MNIST數據集上99+的準確率。
2、數據集來源
MNIST手寫數字數據集是深度學習中的經典數據集,該數據集中的數字圖片是由250個不同職業的人手寫繪製的,其中訓練集數據一共60000張圖片,測試集數據一共10000張圖片。每張手寫數字圖片大小都是28*28,每張圖片代表的是從0到9中的每個數字。數據集官網鏈接: THE MNIST DATABASE.該數據集樣例如下所示:
在 FlyAI競賽平臺上 提供了準確率爲99.26%的超詳細代碼實現,同時我們可以通過參加MNIST手寫數字識別練習賽進行進一步學習和優化。下面的代碼實現部分主要該代碼進行講解。
3、代碼實現
3.1、算法流程及實現
算法流程主要分爲以下四個部分進行介紹:
-
數據加載
-
數據增強
-
構建網絡
-
模型訓練
數據加載
在FlyAI的項目中封裝了Dataset類,可以實現對數據的一些基本操作,比如加載批量訓練數據next_train_batch()和校驗數據next_validation_batch()、獲取全量數據get_all_data()、獲取訓練集數據量get_train_length()和獲取校驗集數據量get_validation_length()等。具體使用方法如下:
# 引入Dataset類 from flyai.dataset import Dataset #創建Dataset類的實例 dataset = Dataset(epochs=5, batch=32) # dataset.get_step()返回訓練總次數 for step in range(dataset.get_step()): #獲取一批訓練數據 x_train, y_train = dataset.next_train_batch() # 獲取一批校驗數據 x_val, y_val = dataset.next_validation_batch()
對單張圖片等數據的讀取是在processor.py文件中完成。實現如下:
import numpy as np import cv2 from flyai.processor.base import Base from path import DATA_PATH import os class Processor(Base): # 讀取一張圖片 def input_x(self, image_path): # 獲取圖片路徑 path = os.path.join(DATA_PATH, image_path) # 讀取圖片 img = cv2.imread(path) # 將圖片BGR格式轉換成RGB格式 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 對圖片進行歸一化操作 img = img / 255.0 # 將圖片轉換成 [28, 28, 1] img = img[:, :, 0] img = img.reshape(28, 28, 1) return img # 讀取該圖片對應的標籤 def input_y(self, label): # 對標籤進行onehot化 one_hot_label = np.zeros([10]) # 生成全0矩陣 one_hot_label[label] = 1 # 相應標籤位置置 return one_hot_label
數據增強
數據增強的作用通常是爲了擴充訓練數據量提高模型的泛化能力,同時通過增加了噪聲數據提升模型的魯棒性。在本項目中我們採用了比較簡單的數據增強方法包括旋轉、平移。實現如下:
#數據增強 data_augment = ImageDataGenerator( # 在 [0, 指定角度] 範圍內進行隨機角度旋轉 rotation_range=10, # 當制定一個數時,圖片同時在長寬兩個方向進行同等程度的放縮操作 zoom_range=0.1, # 水平位置平移 width_shift_range=0.1, # 上下位置平移 height_shift_range=0.1, )
爲了展示數據增強的效果,我們對圖像進行了可視化,完整代碼如下:
from keras.preprocessing.image import ImageDataGenerator from flyai.dataset import Dataset import matplotlib.pyplot as plt import numpy as np #數據增強 data_augment = ImageDataGenerator( # 在 [0, 指定角度] 範圍內進行隨機角度旋轉 rotation_range=10, # 當制定一個數時,圖片同時在長寬兩個方向進行同等程度的放縮操作 zoom_range=0.1, # 水平位置平移 width_shift_range=0.1, # 上下位置平移 height_shift_range=0.1, ) dataset = Dataset(epochs=1, batch=4) for _ in range(dataset.get_step()): x_train, y_train = dataset.next_train_batch() # 展示原始圖片 fig = plt.figure() for i in range(4): img = np.concatenate([x_train[i, :], x_train[i, :], x_train[i, :]], axis=-1) sub_img = fig.add_subplot(241 + i) sub_img.imshow(img) # 對每批圖像做數據增強 batch_gen = data_augment.flow(x_train, y=y_train, batch_size=4) x, y = next(batch_gen) # 對增強之後的圖片進行展示 for i in range(4): img = np.concatenate([x[i,:], x[i,:], x[i,:]], axis=-1) sub_img = fig.add_subplot(241 + i + 4) sub_img.imshow(img) plt.show()
可視化結果如圖:
構建網絡
由於手寫數字圖片大小僅爲28*28,圖像寬高比較小不太適合較深的網絡結構。因此我們自己搭建了一個卷積神經網絡,網絡結構如下所示:
# 構建網絡 sqeue = Sequential() # 第一個卷積層,32個卷積核,大小5x5,卷積模式SAME,激活函數relu,輸入張量的大小 sqeue.add(Conv2D(filters=32, kernel_size=(5, 5), padding='Same', activation='relu', input_shape=(28, 28, 1))) sqeue.add(Conv2D(filters=32, kernel_size=(5, 5), padding='Same', activation='relu')) # 池化層,池化核大小2x2 sqeue.add(MaxPool2D(pool_size=(2, 2))) # 隨機丟棄四分之一的網絡連接,防止過擬合 sqeue.add(Dropout(0.25)) sqeue.add(Conv2D(filters=64, kernel_size=(3, 3), padding='Same', activation='relu')) sqeue.add(Conv2D(filters=64, kernel_size=(3, 3), padding='Same', activation='relu')) sqeue.add(MaxPool2D(pool_size=(2, 2), strides=(2, 2))) sqeue.add(Dropout(0.25)) # 全連接層,展開操作, sqeue.add(Flatten()) # 添加隱藏層神經元的數量和激活函數 sqeue.add(Dense(256, activation='relu')) sqeue.add(Dropout(0.25)) # 輸出層 sqeue.add(Dense(10, activation='softmax')) sqeue.summary() sqeue.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
運行summary()方法後輸出的網絡結構如下圖:
keras提供了keras.utils.vis_utils模塊可以對模型進行可視化操作。
from keras.utils import plot_model plot_model(sqeue, show_shapes=True, to_file='model.png')
模型結構圖如下所示:
模型訓練
這裏我們設置了epoch爲5,batch爲32,採用adam優化器來訓練網絡。通過調用FlyAI提供的train_log方法可以在訓練過程中實時的看到訓練集和驗證集的準確率及損失變化曲線。
from flyai.utils.log_helper import train_log history = sqeue.fit(x, y, batch_size=args.BATCH, verbose=0, validation_data=(x_val, y_val)) # 通過調用train_log方法可以實時看到訓練集和驗證集的準確率及損失變化曲線 train_log(train_loss=history.history['loss'][0], train_acc=history.history['accuracy'][0], val_loss=history.history['val_loss'][0], val_acc=history.history['val_accuracy'][0])
訓練集和驗證集的準確率及損失實時變化曲線如圖:
3.2、最終結果
通過使用自定義CNN網絡結構以及數據增強的方法,在epoch爲5,batch爲32使用adam優化器下不斷優化模型參數,最終模型在測試集的準確率達到99.26%。該項目的可運行完整代碼鏈接https://www.flyai.com/download_temp_code?data_id=MNIST。
參考鏈接: