lenet5 結構 及 pytorch、tensorflow、keras(tf)、paddle實現 mnist手寫數字識別

背景

lenet5網絡源自於Yann LeCun的論文“Gradient-Based Learning Applied to Document Recognition” ,起初被應用於銀行支票的手寫符號識別,經調整後對廣泛應用於手寫數字的識別

 

網絡結構

常用的對minst數據集進行識別的lenet5網絡結構如下

在網上查詢過程中發現對lenet5有 3卷積2連接、2卷積3連接兩種,版本,在原始論文中爲3卷積2池化,因爲沒有填充,卷積後特徵尺寸變爲1*1:

 

input layer: 32*32*1 images

conv1 layer: 5*5*6 conv kernels       28*28*6 output(32 + 0 - 5 + 1 = 28)

pool1 layer: 2*2, 2 maxpool              14*14*6 output( (28 + 0) / 2 )

conv2 layer:  5*5*16 conv kernels    10*10*16 output(14 + 0 - 5 + 1 = 10)

pool2 layer: 2*2, 2 maxpool              5*5*16 output( (10 + 0) / 2 )

conv3 layer:  5*5*120 conv kernels  1*1*120output(5 + 0 - 5 + 1 = 5)

fc1 layer: 84 output(5*5*120 --> 84)

fc2 layer: 10 output(84 --> 10)

 

代碼:

pytorch實現

tensorflow實現

keras實現

paddle實現

 

注:

以上代碼在lenet5的基礎上,實現了:

1)調用框架內置minst數據api讀取數據,

2)進行基本的train、val、inference流程

3)在train時可以輸出各層shape

4)保存最優loss模型,並在結束時輸出最優loss及對應epoch

 

mnist數據集爲28*28*1,在第一次卷積時上下左右  填充2圈,變爲32*32*1

所需測試數據:測試數據pic(內含10張minst圖片)

文件結構:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章