實戰Omniglot數據集識別(手寫數字MNIST升級版)

實戰Omniglot數據集識別(手寫數字MNIST升級版)


最近模式識別老師佈置了一個大作業:手寫字符識別
這裏用到的數據集是Omniglot,如下:

在這裏插入圖片描述
這個數據集可謂是手寫數字識別的爸爸呀,手寫數字識別作爲識別領域最簡單的任務(大概吧),Omniglot相比它的難點就在於:Omniglot具有1623個類別,但每個類別只有20張圖片。
今天正好週末,我們就來打(虐)發(待)一下時(電)間(腦)吧。

經典方法

在做什麼任務之前,我們都應該想想經典方法可不可以實現。
用於分類,我們可以使用SVM或者樸素貝葉斯分類等。
我嘗試了SVM,速度真是慘不忍睹,之後我嘗試了樸素貝葉斯分類法,準確率大概是0.13左右。

簡單方法

最簡單——無腦全連接

我們先從最簡單的方法開始,直接把圖片展成一個向量,然後送入全連接。
網絡結構如下:

784
1024
2048
2048
1623

每一層均使用batchnorm,ReLU激活,並使用DropOut,概率爲0.5。
這方法簡單但準確率並不高,甚至和樸素貝葉斯分類方法準確率不相上下,訓練過程如下:
在這裏插入圖片描述
測試集準確率在0.18左右浮動,就連訓練集準確率也沒有突破0.2。
並不想浪費時間在調整這樣的無腦網絡上,所以我沒有繼續想辦法優化,僅作爲一個嘗試,接下來開始卷積網絡。

簡單卷積網絡

首先作爲嘗試,我先選擇簡單的卷積網絡進行試驗,設計網絡結構如下:

(括號裏表示in_channel,out_channel,kernal_size,stride)
Conv2d(1,64,3,2)
Conv2d(64,128,3,2)
Conv2d(128,256,3,1)
Conv2d(256,512,3,1)
View(1,512*3*3)
Full_connect(512*3*3,2048)
Full_connect(2048,1623)
層間均用ReLU激活,全連接中添加batchnorm層(由於通過嘗試發現卷積層間添加batchnorm時會導致準確率降低,所以在這裏不添加)

訓練設置weight_decay爲1e-4,初始學習率爲1e-4,並以0.9指數每兩輪衰減一次。訓練結果如下:
在這裏插入圖片描述
測試準確率收斂到了0.72,也算好了很多了。同時也可以看出這個數據集確實沒有MNIST那麼簡單。
但在訓練時可以發現,過擬合現象很嚴重,訓練集預測準確率可以達到0.99。通過設置weight_dacay,添加DropOut,也沒有很大的改善。看來只是按照普通的方法進行卷積還是有缺陷。

添加STN與inception

數據集中的圖片每一個字符可能不是正的,這時候就需要網絡具有旋轉不變性。考慮到這一點,我在網絡輸入圖片的時候添加了STN模塊。同時我們識別時也需要從不同尺度看這張圖片,然後通過特徵融合得到不同尺度的特徵,所以我添加了inception的思路。
在這裏插入圖片描述
圖爲STN模塊,具體細節可自行查找。
我使用具體網絡結構如下:

輸入先通過STN模塊調整方向,然後分爲兩路:
第一路:小卷積核
Conv2d(1, 16, kernel_size=3, padding=1)
Conv2d(16, 32, kernel_size=3, padding=1)
第二路:大卷積核
Conv2d(1, 16, kernel_size=7, padding=3)
Conv2d(16, 32, kernel_size=7, padding=3)
將兩路的特徵堆疊連接,送入如下卷積層
Conv2d(64, 128, kernel_size=3, padding=1)
Conv2d(128, 256, kernel_size=3, padding=1)
Conv2d(256, 512, kernel_size=3)
Conv2d(512, 1024, kernel_size=3)
Conv2d(1024, 2048, kernel_size=3)
這時特徵已經縮減爲一個向量,直接送入如下全連接層:
Full_connect(2048, 2048)
Full_connect(2048, 1623)
各層均用ReLU激活,全連接層用DropOut防止過擬合

訓練設置weight_decay爲1e-4,初始學習率爲1e-4,並以0.9指數每兩輪衰減一次。訓練結果如下:
在這裏插入圖片描述
可見測試集準確率上升至0.75左右,上升不是很明顯。

改變loss函數

目前目標識別領域常用的loss函數除了交叉熵損失,還有focal loss,該損失是交叉熵的拓展,往往比交叉熵有更好的效果。通常無用的易分反例樣本會使得模型的整體學習方向跑偏,導致無效學習,所以該損失通過調整權重降低這些樣本的影響,如下:
在這裏插入圖片描述
在這裏插入圖片描述
γ\gamma等於0時,該損失退化爲交叉熵。
使用該損失函數,同時使用之前最簡單的卷積結構,訓練設置weight_decay爲1e-4,初始學習率爲1e-4,並以0.9指數每兩輪衰減一次,γ\gamma取2。訓練結果如下:
在這裏插入圖片描述
可見準確率也上升至了0.75左右。

加深網絡

看來這並不是簡單的任務,我們通過更深的網絡進行嘗試:
直接使用未進行預訓練的ResNet50的結構,將輸出全連接的最後輸出通道改爲1,維度改爲1623,訓練結果如下:
在這裏插入圖片描述
可見效果良好,可以達到0.99準確率,但是收斂慢,訓練慢(畢竟太深了)。在1080ti上跑50個epoch用了半小時。

小樣本學習

一頓亂試之後,我們該靜下來想想爲什麼了,有沒有方法能夠花費較少的時間快速收斂且執行效率高呢?
其實,該任務屬於小樣本學習,即樣本量非常少。目前,解決該難題的方法大致有如下四種:
1、度量學習(metric learning)
2、數據增強(data augmentation)
3、元學習(meta learning)
4、語義的方法(semantic)
我們一一來解釋一下:

度量學習(metric learning)

即將待檢測樣本通過神經網絡Embeding到另一個空間域內,在該空間中,每個樣本爲一個高維點,高維點之間距離越近代表這兩個樣本越可能是同一個類別。距離可以取各種距離,這也就是其名稱“度量”之意。神經網絡需要學習的也就是這樣的一個映射,這裏有一些有名的損失函數如triplet loss與reconstructive loss。

數據增強(data augmentation)

這個應該不用多說了吧,就是通過各種騷方法擴充數據集增加可識別率。

元學習(meta learning)

這應該也是目前的一個熱點,包含面較爲廣泛,其根本用意就是我們常聽到的“learning to learn”。他包含的方法有孿生網絡、原型網絡以及一些其他的方法。

語義的方法(semantic)

由於小樣本學習困難的本質還是在於信息不夠多,我們就想辦法引入一些語義的信息來幫助分類。

下面,我們將運用數據增強以及原型網絡來試一下下:

升級方法嘗試

數據增強

由於對於字符,鏡面翻轉與隨機旋轉都不行,我採用了對每一個字符進行開閉運算的方法將數據集擴充了一倍,這裏使用每類中35張圖片作爲訓練集,5張圖片作爲測試集。使用上一節中最簡單的卷積網絡訓練,使用focal loss,結果如下:
在這裏插入圖片描述
非常amazing啊,測試集準確率訓練10個epoch時達到了0.9,18epoch時達到了0.99。收斂如此之快讓我們領會到了數據的重要性。

原型網絡

原型網絡爲解決小樣本學習的元學習方法中的一種,我這裏運用了原型網絡最初的論文:Prototypical Networks for Few-shot Learning中的方法,簡要介紹一下:
在這裏插入圖片描述
如圖,神經網絡學習一個空間Embeding,將數據映射到另一空間,然後求同一類別的均值,作爲該類別的原型。如下:
在這裏插入圖片描述
其中f即爲該神經網絡:
在這裏插入圖片描述
然後引入一個新的數據,判斷其到每個原型的歐氏距離的softmax值,作爲其屬於該類別的概率:
在這裏插入圖片描述
損失函數要做的就是最大化正確識別時的這個概率,如下:
在這裏插入圖片描述
運用這樣的方法,我使用的網絡結構如下:

Conv2d(1,64,kernal_size=3,stride=2,padding=1),BatchNorm(),ReLU()
Conv2d(64,128,kernal_size=3,stride=2,padding=1),BatchNorm(),ReLU()
Conv2d(128,256,kernal_size=3,stride=2,padding=1),BatchNorm(),ReLU()
Conv2d(256,512,kernal_size=3,stride=2,padding=1),BatchNorm(),ReLU()
Flatten()

然後進行訓練,結果更加amazing:
一輪直接收斂:
每類10個做訓練集,10個做測試集時,訓練一輪後測試集準確率到達0.988,後面最高到達0.99.
每類2個做訓練集,18個做測試集時,一輪訓練後測試集準確率到達0.96,後面最高到達0.97。

總結

完成了老師佈置的作業,終於能去快樂地玩耍了emm。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章