Mnist數據集的分類問題一直被認爲是深度學習的Hello World。利用2層卷積網絡,經過若干輪的訓練後,在相應測試集上的準確率可以達到95%以上。經過調參後,甚至可以達到99%以上。其實,即使不用用卷積層提取特徵,而是用傳統的全連接網絡也同樣可以達到非常高的準確率。在Mnist數據集的官網上(http://yann.lecun.com/exdb/mnist/),除了基於神經網絡的分類器,利用傳統的分類方法,如:KNN,SVM,也都可以獲得非常好的結果。下面就是部分模型分類效果的截圖:
從以上結果分析可以發現,無論是淺層模型還是深度學習,在Mnist上的分類問題上都可以達到很高的精度,因此從某種角度也可以說,Mnist數據集複雜度不夠,或者說Mnist分類問題並不是一個具有代表性的機器視覺問題。就這個問題,《Deep Learning》一書的作者Ian Goodfellow和著名開源項目Keras的作者Francois Chollet都有自己的評述,詳情可轉到下面兩個鏈接:
1.Ian Goodfellow Commnet On Mnist DataSet
雖然Mnist可能並不是最合適入門深度學習的數據集,但是鑑於長期以來開發人員的使用習慣,想要找到完全替代Mnist的開源數據集確實有點困難,但這個難題最近有了一個比較好的解答,就是類似Mnist的一個服裝圖像數據集--Fashion Mnist。
和Mnist數據集一樣,Fashion Mnist也是28*28的灰度圖。內容涵蓋了鞋、包、衣服、褲子。它的文件名稱和數據格式和Mnist一模一樣。換句話說,你完全不需要改動你之前在Mnist上的建模邏輯,只需要把相應的文件替換掉,就可以對Fashion Mnist進行訓練和評估。不過唯一不同的是,Fashion Mnist的分類準確率遠沒有Mnist那麼高。目前在Fashion Mnist的github主頁上,最好的結果也僅僅是在95%左右。當然,如果你自己的網絡有了好的結果,可以在主頁上提個issue,也作爲是對這個數據集的一個貢獻。
下面主要介紹3個方面的內容:
1.Fashion Mnist基於CNN的建模分類與評估
2.與Mnist的比較
3.簡單的服裝分類應用
首先介紹第一部分的主要內容。對於Fashion Mnist數據集採用卷積神經網絡進行分類建模,具體的網絡結構是:2Conv+2FC。建模工具是Deeplearning4j。詳細的超參數配置見如下代碼片段:
public static MultiLayerNetwork getModel(){
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
.seed(12345)
.iterations(1)
//.regularization(true).l2(0.005)
.learningRate(0.01)
.learningRateScoreBasedDecayRate(0.5)
.weightInit(WeightInit.XAVIER)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(Updater.ADAM)
.list()
.layer(0, new ConvolutionLayer.Builder(5, 5)
.nIn(1)
.stride(1, 1)
.nOut(32)
.activation(Activation.LEAKYRELU)
.build())
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2,2)
.stride(2,2)
.build())
.layer(2, new ConvolutionLayer.Builder(5, 5)
.stride(1, 1)
.nOut(64)
.activation(Activation.LEAKYRELU)
.build())
.layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2,2)
.stride(2,2)
.build())
.layer(4, new DenseLayer.Builder().activation(Activation.LEAKYRELU)
.nOut(500).build())
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(10)
.activation(Activation.SOFTMAX)
.build())
.backprop(true).pretrain(false)
.setInputType(InputType.convolutionalFlat(28, 28, 1));
MultiLayerConfiguration conf = builder.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
return model;
}
簡單解釋下部分超參數:激勵函數部分主要用的是LeakeyRelu。
學習率用了Decay的策略。Decay的幅度是50%。
正則化項是可選的(經測試,正則化項在如上配置中,影響不大)
網絡結構:2 Conv-with-MaxPooling + 2FC。卷積層中每層的featureMap的數量如上述所示。
除了建模的部分,數據的ETL部分同樣很重要。在具體實現中,我直接利用Deeplearning4j自帶的一個解析Mnist數據集的組件:MnistManager。它的主要功能就是讀取解壓後的二進制Mnist數據集以及相應的分類標籤。由於Fashion Mnist和原始Mnist數據集在數據格式上完全相同,所以可以直接使用Mnist的組件進行解析。在讀取的時候,我們可以根據要求設置batchSize,一個batch的數據和標籤會封裝在一個DataSet對象中。由這些DataSet構成的迭代器即可作爲最終訓練或者測試的數據。下面具體看下以上邏輯的實現:
public static DataSet fetch(int batchSize , boolean binarize, MnistManager man, boolean save, boolean train) {
float[][] featureData = new float[batchSize][0];
float[][] labelData = new float[batchSize][0];
int actualExamples = 0;
for (int i = 0; i < batchSize && cursor < totalExamples; i++, cursor++) {
byte[] img = man.readImageUnsafe(order[cursor]);
int label = man.readLabel(order[cursor]);
float[] featureVec = new float[img.length];
featureData[actualExamples] = featureVec;
labelData[actualExamples] = new float[10];
labelData[actualExamples][label] = 1.0f;
for (int j = 0; j < img.length; j++) {
float v = ((int) img[j]) & 0xFF; //byte is loaded as signed -> convert to unsigned
if (binarize) {
if (v > 30.0f)
featureVec[j] = 1.0f;
else
featureVec[j] = 0.0f;
} else {
featureVec[j] = v / 255.0f;
}
}
if( save ){
Mat mat = new Mat(28, 28, CV_8SC1, new BytePointer(img));
if( train )
JavaCVUtil.imWrite(mat, "FashionMnist/trainData/" + label + "_" + cursor + ".jpg");
else
JavaCVUtil.imWrite(mat, "FashionMnist/testData/" + label + "_" + cursor + ".jpg");
}
actualExamples++;
}
if (actualExamples < batchSize) {
featureData = Arrays.copyOfRange(featureData, 0, actualExamples);
labelData = Arrays.copyOfRange(labelData, 0, actualExamples);
}
INDArray features = Nd4j.create(featureData);
INDArray labels = Nd4j.create(labelData);
return new DataSet(features, labels);
}
public static DataSetIterator getData(String dir, boolean train , int batchSize, boolean save) throws IOException{
String featureFileDir = dir;
String labelFileDir = dir;
cursor = 0;
if( train ){
featureFileDir += "train-images-idx3-ubyte";
labelFileDir += "train-labels-idx1-ubyte";
totalExamples = 60000;
order = new int[totalExamples];
}else{
featureFileDir += "t10k-images-idx3-ubyte";
labelFileDir += "t10k-labels-idx1-ubyte";
totalExamples = 10000;
order = new int[totalExamples];
}
for (int i = 0; i < order.length; i++)order[i] = i;
MathUtils.shuffleArray(order, 123456L); //shuffle order
MnistManager man = new MnistManager(featureFileDir, labelFileDir, train);
List<DataSet> res = new LinkedList<DataSet>();
while(cursor < totalExamples){
res.add(fetch(batchSize, false, man, save, train));
}
ExistingDataSetIterator iter = new ExistingDataSetIterator(res);
return iter;
}
以上兩個靜態方法就是解析數據、讀取標籤、封裝數據並生成可迭代數據集的過程。其中,getData這個方法可以根據參數的不同,生成訓練或者測試數據集。在fetch這個方法裏,可以選擇是二值化還是正常歸一化。我這裏選擇的是正常歸一化。此外,爲了方便看到Fashion Mnist的圖像形式,可以選擇是否以圖片的形式生成這些圖片。如果生成圖片的話,則可以看到下面這些圖:
這些圖片我會上傳到CSDN上供大家下載。下載連接。
從截圖可以看出是衣服和褲子兩個品類。文件名中的第一個數字是這張圖片的分類標籤。這樣方便直接從圖片進行建模。訓練集共6W張圖片,測試集共1W張圖片。
到此的話,數據的ETL和建模的步驟都已經完成,下面就是對模型參數進行訓練。這裏我還是用的GPU來訓練模型。顯卡是Telsa K80。單卡進行訓練。相應的CUDA版本是8.0。具體的訓練邏輯可見下面代碼片段:
public static void main(String[] args)throws IOException {
DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
final int numEpochs = Integer.parseInt(args[0]);
final int batchSize = Integer.parseInt(args[1]);
final String modelSavePath = args[2];
final String dataPath = args[3];
CudaEnvironment.getInstance().getConfiguration()
// 是否允許多卡
.allowMultiGPU(false)
.useDevice(7)
// 顯存大小
.setMaximumDeviceCache(11L * 1024L * 1024L * 1024L)
// 是否允許多卡直接數據的直接訪問
.allowCrossDeviceAccess(false);
DataSetIterator trainData = getData(dataPath+ "/", true, batchSize, false);
DataSetIterator testData = getData(dataPath+ "/" , false, batchSize, false);
MultiLayerNetwork model = getModel();
for( int i = 0; i < numEpochs; ++i ){
model.fit(trainData);
System.out.println("Epoch :" + i + " Finish");
System.out.println("Score: " + model.score());
Evaluation eval = model.evaluate(testData);
System.out.println(eval.stats());
System.out.println();
}
Evaluation eval = model.evaluate(testData);
System.out.println(eval.stats());
ModelSerializer.writeModel(model, modelSavePath, true);
}
其中batchSize等可以通過args參數傳入設置。注意,最後我們把模型進行了保存。在每一輪的訓練後,我們都打印了損失函數的值,並同時在測試集上評估了此時模型的準確性。我們一共訓練了100輪。下面給出部分訓練過程中的模型信息:Epoch :0 Finish
Score: 0.4417036887606827
==========================Scores========================================
Accuracy: 0.7986
Precision: 0.8004
Recall: 0.7986
F1 Score: 0.7995
========================================================================
第一輪的loss值和模型評估。可以說,效果不佳。和Mnist的第一輪相差甚遠(下面會有Mnist的相應訓練信息)。Epoch :99 Finish
Score: 0.020469321075697797
==========================Scores========================================
Accuracy: 0.9072
Precision: 0.9088
Recall: 0.9072
F1 Score: 0.908
========================================================================
100輪訓練完之後,勉強達到了90%左右。應該說,結果一般。到這裏,我就沒有再訓練下去了。那麼到此,第一部分的主要工作就完成了。最終經過100輪的訓練,loss值達到0.02,模型的準確率在90%。
接着介紹下第二部分,也就是和Mnist比較的內容。
Mnist的訓練過程和上面的一模一樣,唯一不同的是,數據集換成Mnist的就可以了。同樣經過100輪的訓練,我們來看下對比結果。
Mnist DataSet | Fashion Mnist DataSet | |
Epoch 1 | ==========================Scores==================================== Accuracy: 0.9545 Precision: 0.955 Recall: 0.954 F1 Score: 0.9545 ==================================================================== | ==========================Scores=================================== Accuracy: 0.7986 Precision: 0.8004 Recall: 0.7986 F1 Score: 0.7995 =================================================================== |
Epoch 100 | ==========================Scores==================================== Accuracy: 0.9922 Precision: 0.9921 Recall: 0.9921 F1 Score: 0.9921 ==================================================================== | ==========================Scores===================================== Accuracy: 0.9072 Precision: 0.9088 Recall: 0.9072 F1 Score: 0.908 ===================================================================== |
Mnist數據集很容易就達到了95%的準確率,甚至最後達到了99.22%。然而Fashion Mnist最終也只有徘徊在90%上下。由此可見,Fashion Mnist數據集的分類問題更爲複雜。2層卷積神經網絡的效果可能也就是在90%左右了(PS:這個講述並沒有什麼理論依據,但從github主頁看到他人用Keras搭建類似結構的網絡來訓練Fashion Mnist,也是在90%上下,所以作此推測,僅僅是實驗結果)。
最後一個部分介紹下基於剛纔訓練的模型如何搭建一個Web應用。
服裝的分類場景在各大電商企業中有很多應用。雖然不一定需要準確區分運動鞋和休閒鞋,但是區分衣服、褲子、包、鞋還是很有必要的。這個場景在圖像檢索等應用方面有着類似文本檢索中Query分析的作用,最終可以減少索引的查詢量。這裏就直接利用這樣的一個開源數據集搭建一個Web服務,用於識別圖片中物品的所屬品類。涉及到的工具有Spring、Tomcat,JSP,還有之前提到的Deeplearning4j和Nd4j。
我在本地的Eclipse中配置了Tomcat的插件、服務的端口號、上下文的根路徑等。在POM文件中引入了Spring和Deeplearning4j的相關依賴。最後前端頁面上做了個簡單的上傳圖片的按鈕,最後的模型分類結果會和圖片一起在頁面上做展示。由於這裏面涉及了關於J2EE開發的諸多細節,和主要介紹的內容有些偏離,所以這裏僅僅介紹主要的思路。在後面的文章中,如果有機會的話會詳細介紹Deeplearning4j訓練的模型上線部署的一些方式,當然也包括一些採坑的地方。下面就給出一些示例結果:
這裏有個地方需要注意:Fashion Mnist是28*28的灰度圖。在做這個實際應用場景的時候,我同樣對這些彩色圖片做了灰度化以及resize的處理。換句話說,和訓練數據保持一致對預測結果也同樣重要。
最後對上面的內容做下小結。Fashion Mnist作爲Mnist的替代數據集,無論在數據格式還是文件名稱上都和原始的Mnist保持了高度一致,從而方便研發人員遷移之前的工作。但是,Fashion Mnist的分類比Mnist更有挑戰性,至少從目前github主頁上最優結果以及我自己的實驗來看,很難達到和Mnist一樣的準確性。原因的話,像外套、襯衫;靴子、運動鞋;難免存在外形極其相似的情況。因此,誤判的情況會比較多。不過,從另一個角度說,這也說明相比Mnist,Fashion Mnist數據集的分類問題更爲複雜。此外,Fashion Mnist也可以作爲諸多電商企業商品圖片分類的一個demo級別的測試數據集。通過做服裝商品分類這樣一個應用,可以對深度學習在產品級別應用的問題上有感性的認識,更重要的可能是發現深度神經網絡的侷限性,並非是萬能的。這可能也是Fashion Mnist相比於原始Mnist數據集的價值所在,讓大家對深度學習有理性的認識(原始Mnist很容易達到98%-99%的準確率,容易誤導大家覺得深度學習就是這樣準確,其實數據集本身也有非常大的關係,不能僅僅依靠模型)。
--------------以下更新自2018/3/21
在Deeplearning4j的QQ羣裏還有這篇文章的留言區有同學希望我補充下Web部分的代碼,特此在這裏做些補充。
首先,Web容器我用的是Tomcat-Eclipse的插件,在pom裏的配置如下:
<build>
<finalName>dl-webapp</finalName>
<plugins>
<plugin>
<groupId>org.apache.tomcat.maven</groupId>
<artifactId>tomcat7-maven-plugin</artifactId>
<version>2.2</version>
<configuration>
<port>8080</port>
<path>/maven-web-demo</path>
<uriEncoding>UTF-8</uriEncoding>
<finalName>maven-web-demo</finalName>
<server>tomcat7</server>
</configuration>
<executions>
<!-- 打包成功後即開始運行web容器 -->
<execution>
<phase>package</phase>
<goals>
<goal>run</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
其次,整個Web工程的編譯目錄結構如下(工程名:DL):