使用Dl4j训练的一个手写数字识别软件

DL4J使用之手写数字识别

最近一直在学习深度学习,由于我是Java程序员出身,就选择了一个面向Java的深度学习库—DL4J。为了更加熟练的掌握这个库的使用,我使用该库,以MNIST(http://yann.lecun.com/exdb/mnist/)手写数字数据集作为基础,训练了一个模型,来识别手写字体。下面我们从以下几个方面讲解该项目的实现:

DL4J简介

Deeplearning4j是国外创业公司Skymind的产品。目前最新的版本更新到了0.7.2。源码全部公开并托管在github上(https://github.com/deeplearning4j/deeplearning4j)。从这个库的名字上可以看出,它就是转为Java程序员写的Deep Learning库。其实这个库吸引人的地方不仅仅在于它支持Java,更为重要的是它可以支持Spark。由于Deep Learning模型的训练需要大量的内存,而且原始数据的存储有时候也需要很大的外存空间,所以如果可以利用集群来处理便是最好不过了。当然,除了Deeplearning4j以外,还有一些Deep Learning的库可以支持Spark,比如yahoo/CaffeOnSpark,AMPLab/SparkNet以及Intel最近开源的BigDL。这些库我自己都没怎么用过,所以就不多说了,这里重点说说Deeplearning4j的使用。
从项目管理角度,DL4J官方给的例子中,推荐使用Maven构建项目,但是目前在学习阶段,我是直接从官网扣下来了需要的Jar包导入项目,这样有一个好处,在项目迁移到别的计算机上运行的时候不需要等待Maven下载jar包的时间。当然,工作中还是推荐使用Maven。不说了,下面是我提出来的Jar包:
这里写图片描述
看着还是挺庞大的,其实也难怪,毕竟深度学习需要大量的工作才能形成一个库。这些我已经上传到CSDN可以点击下方链接下载(https://download.csdn.net/download/yushengpeng/10286975

模型的训练

训练数据集(MNIST)

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。它有60000个训练样本集和10000个测试样本集。MNIST算是深度学习入门的一个数据集吧,也是一个比较优秀的手写数字数据集,可以用于半监督学习,并且取得了非常不错的成绩。下面是该数据集的部分截图:
这里写图片描述
关于如何将该数据集转换成DL4J能识别的格式,请学习DL4J的官方文档。我也上传了Dl4J的官方文档到了CSND,如果你有需求请前去下载(https://download.csdn.net/download/yushengpeng/10287018)。

模型架构

当我们正确读取数据后,我们需要定义具体的神经网络结构,这里我用的是Lenet,该网络是一个5层的神经网络(在深度学习中,我们约定俗成的认为输入层是第0层不参与层数统计),该网络各层情况如下:

第0层: nput layer: 输入数据为原始训练图像
第1层: Conv1:6个5*5的卷积核,步长Stride为1
第2层:Pooling1:卷积核size为2*2,步长Stride为2
第3层:Conv2:12个5*5的卷积核,步长Stride为1
第4层:Pooling2:卷积核size为2*2,步长Stride为2
第5层:Output layer:输出为10维向量

网络层级结构示意图如下:
这里写图片描述

Deeplearning4j的实现参考了官网(https://github.com/deeplearning4j/dl4j-examples)的例子。具体代码如下:

public class CNN_MNIST {
    private static Logger log = LoggerFactory.getLogger(CNN_MNIST.class);
    public static void main(String[] args) throws IOException {
        int nChannels = 1;
        int outputNum = 10; // The number of possible outcomes
        int batchSize = 64; // Test batch size
        int nEpochs = 2; // Number of training epochs
        int iterations = 1; // Number of training iterations
        int seed = 123; //
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations)
                .regularization(true).l2(0.0005).learningRate(.01).weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.NESTEROVS)
                .momentum(0.9).list().layer(0, new ConvolutionLayer.Builder(5, 5)
                        // nIn and nOut specify depth. nIn here is the nChannels and
                        // nOut is the number of filters to be applied
                        .nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build())
                .layer(1,
                        new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2)
                                .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        // Note that nIn need not be specified in later layers
                        .stride(1, 1).nOut(50).activation(Activation.IDENTITY).build())
                .layer(3,
                        new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2)
                                .build())
                .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())
                .layer(5,
                        new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum)
                                .activation(Activation.SOFTMAX).build())
                .setInputType(InputType.convolutionalFlat(28, 28, 1)) // See note below
                .backprop(true).pretrain(false).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        model.setListeners(new ScoreIterationListener(1));
        for (int i = 0; i < nEpochs; i++) {
            model.fit(mnistTrain);
            log.info("*** Completed epoch {} ***", i);
            log.info("Evaluate model....");
            Evaluation eval = new Evaluation(outputNum);
            while (mnistTest.hasNext()) {
                DataSet ds = mnistTest.next();
                INDArray output = model.output(ds.getFeatureMatrix(), false);
                eval.eval(ds.getLabels(), output);
            }
            log.info(eval.stats());
            mnistTest.reset();
            log.info("****************Example finished********************");

            log.info("******SAVE TRAINED MODEL******");
            // Details

            // Where to save model
            File locationToSave = new File("trained_mnist_model.zip");

            // boolean save Updater
            boolean saveUpdater = false;

            // ModelSerializer needs modelname, saveUpdater, Location

            ModelSerializer.writeModel(model, locationToSave, saveUpdater);

        }
    }
}

可以发现,神经网络需要定义很多的超参数,学习率、正则化系数、卷积核的大小、激励函数等都是需要人为设定的。不同的超参数,对结果的影响很大,其实后来发现,很多时间都花在数据处理和调参方面。毕竟自己设计网络的能力有限,一般都是参考大牛的论文,然后自己照葫芦画瓢地实现。这里实现的Lenet的结构是:卷积–>下采样–>卷积–>下采样–>全连接。和原论文的结构基本一致。卷积核的大小也是参考的原论文。具体细节可参考之前发的论文链接。这里我们设置了一个Score的监听事件,主要是可以在训练的时候获取每一次权重更新后损失函数的收敛情况,如下面所示:
这里写图片描述

模型性能

of classes: 10
Accuracy 0.9918
Precision 0.9917
Recall 0.9917
F1 Score 0.9917

模型性能还是不错的,在10000个手写数字测试集上的准确率能达到99.17%。当然,模型的好坏跟神经网络的架构,超参的设置都有关系,关于到底选用什么样的模型架构需要更多的经验,知识。一般具体问题具体分析。

模型的保存与加载

当我们训练好了一个模型的时候,我们需要将训练好的模型持久化到本地磁盘,或者其他存储介质。因为训练模型是一个非常耗时的工作,模型的大小,数据集的大小,训练一个模型需要一天,一周,一个月,甚至是更长的时间。我们不可能每次在实际的项目中,需要的时候再去训练出一个模型。DL4J也为我们实现了模型的持久化功能,具体代码如下:

File locationToSave = new File("trained_mnist_model.zip");//保存路径,存储位置
boolean saveUpdater = false;
ModelSerializer.writeModel(model, locationToSave, saveUpdater);

当然持久化模型是为了再次加载模型,使用模型。DL4J也为我们实现了模型的的加载功能,具体代码如下:

NativeImageLoader loader = new NativeImageLoader(28, 28, 1);
INDArray image = loader.asMatrix(new File("XXX://test.jpg"));//从本地磁盘中加载文件
DataNormalization scaler = new ImagePreProcessingScaler(0,1);
scaler.transform(image);
INDArray output = model.output(image);//对图片进行分类预测

结果展示

这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述

这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述

总结 与展望

通过这个小项目,我参照官网手册,初步实现了LENET网络。并取得了不错的成果。当然,我也是学习了充足的理论之后,再来学习DL4J这个深度学习框架的。关于这个项目的源码,你可以去我的GutHub上下载:https://github.com/ShengPengYu/writtingRecoginition
。该项目还有不足之处,比方说可以边测试边学习,我们在发现我们书写的测试数据分类不准确的时候,可以加入到训练数据库,在线对模型实时训练,因为每个用户的书写风格不一样,可能对分类结果有一定的影响。边测试边训练,可以训练出符合用户个人情况的模型。还有一种情况是。当然我也有一定的思考,比方说,如果我对目前模型进一步改进,做一个汉字识别项目,那么最后一层该使用什么架构,中国汉字那么多,如果使用one-hot模式,会不会维度太大,在时间复杂度和空间复杂度上是一个非常严峻的问题。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章