使用Dl4j訓練的一個手寫數字識別軟件

DL4J使用之手寫數字識別

最近一直在學習深度學習,由於我是Java程序員出身,就選擇了一個面向Java的深度學習庫—DL4J。爲了更加熟練的掌握這個庫的使用,我使用該庫,以MNIST(http://yann.lecun.com/exdb/mnist/)手寫數字數據集作爲基礎,訓練了一個模型,來識別手寫字體。下面我們從以下幾個方面講解該項目的實現:

DL4J簡介

Deeplearning4j是國外創業公司Skymind的產品。目前最新的版本更新到了0.7.2。源碼全部公開並託管在github上(https://github.com/deeplearning4j/deeplearning4j)。從這個庫的名字上可以看出,它就是轉爲Java程序員寫的Deep Learning庫。其實這個庫吸引人的地方不僅僅在於它支持Java,更爲重要的是它可以支持Spark。由於Deep Learning模型的訓練需要大量的內存,而且原始數據的存儲有時候也需要很大的外存空間,所以如果可以利用集羣來處理便是最好不過了。當然,除了Deeplearning4j以外,還有一些Deep Learning的庫可以支持Spark,比如yahoo/CaffeOnSpark,AMPLab/SparkNet以及Intel最近開源的BigDL。這些庫我自己都沒怎麼用過,所以就不多說了,這裏重點說說Deeplearning4j的使用。
從項目管理角度,DL4J官方給的例子中,推薦使用Maven構建項目,但是目前在學習階段,我是直接從官網扣下來了需要的Jar包導入項目,這樣有一個好處,在項目遷移到別的計算機上運行的時候不需要等待Maven下載jar包的時間。當然,工作中還是推薦使用Maven。不說了,下面是我提出來的Jar包:
這裏寫圖片描述
看着還是挺龐大的,其實也難怪,畢竟深度學習需要大量的工作才能形成一個庫。這些我已經上傳到CSDN可以點擊下方鏈接下載(https://download.csdn.net/download/yushengpeng/10286975

模型的訓練

訓練數據集(MNIST)

MNIST 數據集來自美國國家標準與技術研究所, National Institute of Standards and Technology (NIST). 訓練集 (training set) 由來自 250 個不同人手寫的數字構成, 其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員. 測試集(test set) 也是同樣比例的手寫數字數據。它有60000個訓練樣本集和10000個測試樣本集。MNIST算是深度學習入門的一個數據集吧,也是一個比較優秀的手寫數字數據集,可以用於半監督學習,並且取得了非常不錯的成績。下面是該數據集的部分截圖:
這裏寫圖片描述
關於如何將該數據集轉換成DL4J能識別的格式,請學習DL4J的官方文檔。我也上傳了Dl4J的官方文檔到了CSND,如果你有需求請前去下載(https://download.csdn.net/download/yushengpeng/10287018)。

模型架構

當我們正確讀取數據後,我們需要定義具體的神經網絡結構,這裏我用的是Lenet,該網絡是一個5層的神經網絡(在深度學習中,我們約定俗成的認爲輸入層是第0層不參與層數統計),該網絡各層情況如下:

第0層: nput layer: 輸入數據爲原始訓練圖像
第1層: Conv1:6個5*5的卷積核,步長Stride爲1
第2層:Pooling1:卷積核size爲2*2,步長Stride爲2
第3層:Conv2:12個5*5的卷積核,步長Stride爲1
第4層:Pooling2:卷積核size爲2*2,步長Stride爲2
第5層:Output layer:輸出爲10維向量

網絡層級結構示意圖如下:
這裏寫圖片描述

Deeplearning4j的實現參考了官網(https://github.com/deeplearning4j/dl4j-examples)的例子。具體代碼如下:

public class CNN_MNIST {
    private static Logger log = LoggerFactory.getLogger(CNN_MNIST.class);
    public static void main(String[] args) throws IOException {
        int nChannels = 1;
        int outputNum = 10; // The number of possible outcomes
        int batchSize = 64; // Test batch size
        int nEpochs = 2; // Number of training epochs
        int iterations = 1; // Number of training iterations
        int seed = 123; //
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations)
                .regularization(true).l2(0.0005).learningRate(.01).weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.NESTEROVS)
                .momentum(0.9).list().layer(0, new ConvolutionLayer.Builder(5, 5)
                        // nIn and nOut specify depth. nIn here is the nChannels and
                        // nOut is the number of filters to be applied
                        .nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build())
                .layer(1,
                        new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2)
                                .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        // Note that nIn need not be specified in later layers
                        .stride(1, 1).nOut(50).activation(Activation.IDENTITY).build())
                .layer(3,
                        new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2)
                                .build())
                .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())
                .layer(5,
                        new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum)
                                .activation(Activation.SOFTMAX).build())
                .setInputType(InputType.convolutionalFlat(28, 28, 1)) // See note below
                .backprop(true).pretrain(false).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        model.setListeners(new ScoreIterationListener(1));
        for (int i = 0; i < nEpochs; i++) {
            model.fit(mnistTrain);
            log.info("*** Completed epoch {} ***", i);
            log.info("Evaluate model....");
            Evaluation eval = new Evaluation(outputNum);
            while (mnistTest.hasNext()) {
                DataSet ds = mnistTest.next();
                INDArray output = model.output(ds.getFeatureMatrix(), false);
                eval.eval(ds.getLabels(), output);
            }
            log.info(eval.stats());
            mnistTest.reset();
            log.info("****************Example finished********************");

            log.info("******SAVE TRAINED MODEL******");
            // Details

            // Where to save model
            File locationToSave = new File("trained_mnist_model.zip");

            // boolean save Updater
            boolean saveUpdater = false;

            // ModelSerializer needs modelname, saveUpdater, Location

            ModelSerializer.writeModel(model, locationToSave, saveUpdater);

        }
    }
}

可以發現,神經網絡需要定義很多的超參數,學習率、正則化係數、卷積核的大小、激勵函數等都是需要人爲設定的。不同的超參數,對結果的影響很大,其實後來發現,很多時間都花在數據處理和調參方面。畢竟自己設計網絡的能力有限,一般都是參考大牛的論文,然後自己照葫蘆畫瓢地實現。這裏實現的Lenet的結構是:卷積–>下采樣–>卷積–>下采樣–>全連接。和原論文的結構基本一致。卷積核的大小也是參考的原論文。具體細節可參考之前發的論文鏈接。這裏我們設置了一個Score的監聽事件,主要是可以在訓練的時候獲取每一次權重更新後損失函數的收斂情況,如下面所示:
這裏寫圖片描述

模型性能

of classes: 10
Accuracy 0.9918
Precision 0.9917
Recall 0.9917
F1 Score 0.9917

模型性能還是不錯的,在10000個手寫數字測試集上的準確率能達到99.17%。當然,模型的好壞跟神經網絡的架構,超參的設置都有關係,關於到底選用什麼樣的模型架構需要更多的經驗,知識。一般具體問題具體分析。

模型的保存與加載

當我們訓練好了一個模型的時候,我們需要將訓練好的模型持久化到本地磁盤,或者其他存儲介質。因爲訓練模型是一個非常耗時的工作,模型的大小,數據集的大小,訓練一個模型需要一天,一週,一個月,甚至是更長的時間。我們不可能每次在實際的項目中,需要的時候再去訓練出一個模型。DL4J也爲我們實現了模型的持久化功能,具體代碼如下:

File locationToSave = new File("trained_mnist_model.zip");//保存路徑,存儲位置
boolean saveUpdater = false;
ModelSerializer.writeModel(model, locationToSave, saveUpdater);

當然持久化模型是爲了再次加載模型,使用模型。DL4J也爲我們實現了模型的的加載功能,具體代碼如下:

NativeImageLoader loader = new NativeImageLoader(28, 28, 1);
INDArray image = loader.asMatrix(new File("XXX://test.jpg"));//從本地磁盤中加載文件
DataNormalization scaler = new ImagePreProcessingScaler(0,1);
scaler.transform(image);
INDArray output = model.output(image);//對圖片進行分類預測

結果展示

這裏寫圖片描述
這裏寫圖片描述
這裏寫圖片描述
這裏寫圖片描述

這裏寫圖片描述
這裏寫圖片描述
這裏寫圖片描述
這裏寫圖片描述
這裏寫圖片描述
這裏寫圖片描述

總結 與展望

通過這個小項目,我參照官網手冊,初步實現了LENET網絡。並取得了不錯的成果。當然,我也是學習了充足的理論之後,再來學習DL4J這個深度學習框架的。關於這個項目的源碼,你可以去我的GutHub上下載:https://github.com/ShengPengYu/writtingRecoginition
。該項目還有不足之處,比方說可以邊測試邊學習,我們在發現我們書寫的測試數據分類不準確的時候,可以加入到訓練數據庫,在線對模型實時訓練,因爲每個用戶的書寫風格不一樣,可能對分類結果有一定的影響。邊測試邊訓練,可以訓練出符合用戶個人情況的模型。還有一種情況是。當然我也有一定的思考,比方說,如果我對目前模型進一步改進,做一個漢字識別項目,那麼最後一層該使用什麼架構,中國漢字那麼多,如果使用one-hot模式,會不會維度太大,在時間複雜度和空間複雜度上是一個非常嚴峻的問題。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章