在日常生活中,數據的數量並不是相等的。即使是在超大型數據集中,數據的數量差異也廣泛存在,例如下圖中SUN-397中的數據分佈情況。臥室的數據可以達到1000以上,但圖書館甚至不到50。
在本文中,這種數據分佈情況被稱爲“長尾”(long tail)。對於處在尾部的數據而言,機器學習算法實際上是在執行小樣本學習任務(N-shot learning),這直接導致了相關領域識別的準確率不高。
爲解決這一問題,本文提出了一種基於元學習的方法。其思想入下圖所示。對於具有較多樣本的數據,使用基於殘差網絡(ResNet的思想)的元學習器學習這些學習的進程(Learn to learn),從而使用較少數據實現快速學習。
而他們的具體做法是如下圖所示。他們使用的是一個N個Block的殘差網絡,每個模塊有獨立的參數。他們的具體設想是:第i個殘差模塊的輸入是由個數據訓練所得的CNN網絡的參數(這裏他們保持CNN前部特徵提取網絡的參數不變,只將最後一層4096->1000中每一個4096->1的參數θ作爲了Transfer Learning的對象)。也即,第個Block的作用就是將對應於個數據的訓練參數映射爲對應於的參數。
這種方法的最大好處是,後續的每一個class都可以選擇與自己數據量最爲契合的輸入點,通過其後數層Res網絡獲得最終的優化參數。原文中使用“幻覺”(hallucinate)形象地描述這一參數演化過程。
網絡的訓練方面,文中使用了遞歸的訓練方式:先用樣本數量大於的類訓練最後一層ResN,然後遞歸的逐漸包含前面的層。在訓練第層時,使用的所有類的樣本數量都於,這樣可以保證用於訓練的隨機選擇樣本數量小於該類總樣本量的一半。
如此這個網絡的思想就成型了:當要訓練的層由上圖中所對應的的變成所對應的,所用的訓練樣本來自的類的最小樣本數量便會下降一半,如此便有更多的類參與到了訓練過程中。這便是本文在摘要部分提出的“從頭到身體再到尾巴”的訓練方法,前面的類都會用於提高後面的類的訓練效果。
如此,網絡的思想就講完了。剩下的一些損失函數定義等請讀者看下原文吧。雖然這篇文章是NIPS,但我還是要吐槽:這從小往大寫的方法介紹真是完全讀不懂!讀到損失函數什麼的時候我本人完全是懵逼的,而上面我的講解則是直接將他的介紹倒過來了,是不是覺得清晰了很多!!!