【深度學習筆記整理-2.1】深度學習的Hello World---MNIST手寫識別

1.MNIST是keras中一個關於手寫識別的數據集,其中有6萬張訓練圖像與1萬張測試圖像。

keras中共有七個內置數據集,

注:CIFAR10與CIFAR100分別爲分類數爲10和分類數100的兩個圖片數據集。

【除內置數據庫外,UCI中還有較多免費資料庫】

2.引入數據集

from keras.datasets import mnist

(train_datas,train_labels),(test_datas,test_labels)=mnist.load_data()

 

train_images.shape  #形狀爲一三維張量:(60000,28,28),即6萬張28*28的圖片
len(train_labels)  #label爲一向量,維度與圖片個數對應,爲60000

3.架構神經網絡(Sequential類型)

from keras import models
from keras import layers

# models是文件名,Sequential對應其中的一個類,這裏是實例化一個對象

network = models.Sequential() 

#網絡第一層需要給輸入的向量形狀,後續則無需再給,全連接網絡需要將輸入的圖片拉直成一個向量輸入,所謂拉直就是將圖片按行排列成一列向量。

network.add(layers.Dense(512,activation = 'relu',input_shape = (28*28,)))

#使用softmax分類

network.add(layers.Dense(10,activation = 'softmax'))

 

4.查看網絡

network.summary()

我們需要估計的參數有407050個,第一層需要估計一個784*512=401408個參數的矩陣,又需要加上一個512維的偏移向量,一共需要401408+512=401920個參數,第二層需要估計一個512*10=5120個參數的矩陣加上10維的偏移向量,共5130個參數,兩層共需401920+5130=407050個參數。

5.編譯網絡

network.compile(optimizer = 'rmsprop', loss = 'categorical_crossentropy', metrics=['accuracy'])

optimizer指明使用何種優化梯度下降的方法,keras支持的有:SGD(隨機梯度下降),RMSprop(本質上是weighted sum版的RMS),Adagrad(考慮不同梯度參數以往更新的RMS),Adam(考慮動量)等等

loss指明何種損失,用於衡量估計誤差,多分類任務常使用交叉熵。

metrics用於我們觀察模型的好壞,這裏使用分類正確率作爲標準。

6.數據預處理

我們希望我們輸入向量的值維持在0左右,常見的方法有壓縮到0和1之間,或是將其變爲常態分佈,這裏我們採取第一種方法,將圖片的每一個值除以255,然後使用reshape將其拉平,除以255後會出現小數,所以將其類別轉爲float型。

train_images = train_images.reshape((60000,28*28)) # 注意reshape要的輸入是個元組
train_images = train_images.astype('float32')/255

test_images = test_images.reshape((10000,28*28))
test_images = test_images.astype('float32')/255

將類別數字組成的向量轉化爲one-hot向量組成的矩陣

from keras.utils import to_categorical

train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

7.訓練網絡

採用batch大小爲128,5輪來訓練。

network.fit(train_images,train_labels,epoches=5,batch_size=128)

batch_size 代表神經網絡根據多少數據進行一次迭代(iteration:包含一次正向傳播,反向傳播)

一次epoch表示所有數據跑完,一次epoch需要經過的iteration(更新次數)爲所有數據數/batch_size,最終我們更新了epoches*每次epoches更新的iteration數。

8.衡量測試資料

這裏越過驗證集調參的過程,直接對測試集進行衡量好壞。

test_loss, test_acc = network.evaluate(test_images, test_labels)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章