背景
lenet5網絡源自於Yann LeCun的論文“Gradient-Based Learning Applied to Document Recognition” ,起初被應用於銀行支票的手寫符號識別,經調整後對廣泛應用於手寫數字的識別
網絡結構
常用的對minst數據集進行識別的lenet5網絡結構如下
在網上查詢過程中發現對lenet5有 3卷積2連接、2卷積3連接兩種,版本,在原始論文中爲3卷積2池化,因爲沒有填充,卷積後特徵尺寸變爲1*1:
input layer: 32*32*1 images
conv1 layer: 5*5*6 conv kernels 28*28*6 output(32 + 0 - 5 + 1 = 28)
pool1 layer: 2*2, 2 maxpool 14*14*6 output( (28 + 0) / 2 )
conv2 layer: 5*5*16 conv kernels 10*10*16 output(14 + 0 - 5 + 1 = 10)
pool2 layer: 2*2, 2 maxpool 5*5*16 output( (10 + 0) / 2 )
conv3 layer: 5*5*120 conv kernels 1*1*120output(5 + 0 - 5 + 1 = 5)
fc1 layer: 84 output(5*5*120 --> 84)
fc2 layer: 10 output(84 --> 10)
代碼:
注:
以上代碼在lenet5的基礎上,實現了:
1)調用框架內置minst數據api讀取數據,
2)進行基本的train、val、inference流程
3)在train時可以輸出各層shape
4)保存最優loss模型,並在結束時輸出最優loss及對應epoch
mnist數據集爲28*28*1,在第一次卷積時上下左右 填充2圈,變爲32*32*1
所需測試數據:測試數據pic(內含10張minst圖片)
文件結構: