Handwritten Digit Recognition with Deeplearning4j

Deep learning has been hot these past few years, and plenty of machine-learning novices like me, who only half understand the field, have started building applications with it. Since I'm not skilled enough to write the algorithms myself, I turned to open-source libraries. There are many open-source deep learning libraries; grouped by language, there are the Python ones (TensorFlow, Theano, Keras), the C/C++ ones (Caffe), the Lua ones (Torch), and so on. Our company, however, works mainly in Java, and most of our projects ultimately ship as Java web services, so in the end I chose Deeplearning4j.

    Deeplearning4j is a product of the startup Skymind. The latest release at the time of writing is 0.7.2, and the source is fully open and hosted on GitHub (https://github.com/deeplearning4j/deeplearning4j). As the name suggests, it is a deep learning library written specifically for Java programmers. What makes it attractive is not just the Java support but, more importantly, its support for Spark. Training deep learning models takes a lot of memory, and storing the raw data can take a lot of disk space too, so being able to spread the work across a cluster is ideal. Of course, Deeplearning4j is not the only deep learning library that can run on Spark: there are also yahoo/CaffeOnSpark, AMPLab/SparkNet, and BigDL, which Intel open-sourced recently. I haven't really used those myself, so I won't say much about them; here I'll focus on using Deeplearning4j.

    When you pick up someone else's library, you usually start by running a demo, a "Hello World" example, just as the first line of code you write in a new programming language prints Hello World. Deep learning has two customary Hello World examples: classifying the MNIST dataset, and finding similar words with Word2Vec. Since Word2Vec is not strictly a deep neural network, I'll use a LeNet network on the MNIST dataset as the deep learning Hello World here. MNIST is an open dataset of 28x28 black-and-white images of handwritten digits (http://yann.lecun.com/exdb/mnist/), containing 60,000 training images and 10,000 test images. For a description of the LeNet architecture, see http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf. Below I'll walk through building, training, and evaluating the model with Deeplearning4j.

    First, create a Maven project and add the Deeplearning4j dependencies to the pom file. The three main ones are deeplearning4j-core, datavec, and nd4j. deeplearning4j-core implements the neural network structures; nd4j is the tensor-math library, which calls precompiled C++ backends (optionally ATLAS, MKL, or OpenBLAS) through JavaCPP; and datavec handles the data ETL. Concretely:

<properties>
  <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
  <nd4j.version>0.7.1</nd4j.version>
  <dl4j.version>0.7.1</dl4j.version>
  <datavec.version>0.7.1</datavec.version>
  <scala.binary.version>2.10</scala.binary.version>
</properties>
<dependencies>
  <dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native</artifactId>
    <version>${nd4j.version}</version>
  </dependency>
  <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>dl4j-spark_${scala.binary.version}</artifactId>
    <version>${dl4j.version}</version>
  </dependency>
  <dependency>
    <groupId>org.datavec</groupId>
    <artifactId>datavec-spark_${scala.binary.version}</artifactId>
    <version>${datavec.version}</version>
  </dependency>
  <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>${dl4j.version}</version>
  </dependency>
</dependencies>
    Some of these dependencies are Spark-related and only matter when running on Spark, but it does no harm to pull them in now. Just make sure the Scala binary version suffix is consistent across the Spark artifacts.

    Next, let's go through the code. We first define some concrete parameters, such as the number of classes (outputNum) and the mini-batch size (batchSize); each is annotated in the code below. The class worth explaining is MnistDataSetIterator, a high-level wrapper for reading the binary MNIST dataset. Stepping through it with a debugger shows that it downloads the MNIST dataset from the network, reads the data and labels, and then builds the iterator. By default, the source saves the downloaded files under the system's user.home directory, which differs from machine to machine. My own network connection is poor, so this high-level interface is quite likely to throw an exception because the MNIST download fails, making training impossible. If that happens, you can download the files yourself in advance, then place and name them exactly as the source expects, and the high-level interface will still work; see the relevant code in the MnistDataFetcher class, and the sketch after the code below.

int nChannels = 1;      // black & white images; use 3 for color images
int outputNum = 10;     // number of classes
int batchSize = 64;     // mini-batch size for SGD
int nEpochs = 10;       // total number of training epochs
int iterations = 1;     // number of iterations in each training round
int seed = 123;         // random seed for weight initialization

log.info("Load data....");
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
DataSetIterator mnistTest  = new MnistDataSetIterator(batchSize, false, 12345);
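
    If the automatic download keeps failing, you can verify that your manually downloaded files are in place before constructing the iterators. Below is a minimal sketch; the directory and file names are my assumptions from reading MnistDataFetcher in the 0.7.x source (the four standard MNIST files, unpacked, under {user.home}/MNIST), so double-check them against that class in the version you use.

import java.io.File;

public class MnistFileCheck {
    public static void main(String[] args) {
        // Assumed download directory of MnistDataFetcher: {user.home}/MNIST
        File mnistDir = new File(System.getProperty("user.home"), "MNIST");

        // The four standard MNIST files, unpacked from their .gz archives
        String[] expected = {
                "train-images-idx3-ubyte",
                "train-labels-idx1-ubyte",
                "t10k-images-idx3-ubyte",
                "t10k-labels-idx1-ubyte"
        };

        for (String name : expected) {
            File f = new File(mnistDir, name);
            System.out.println(f.getAbsolutePath() + (f.exists() ? "  found" : "  MISSING"));
        }
    }
}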

 

Once the data is read correctly, we need to define the network structure itself. Here I use LeNet; my Deeplearning4j implementation follows the official examples (https://github.com/deeplearning4j/dl4j-examples). The code:

MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .regularization(true).l2(0.0005)
                .learningRate(0.01)//.biasLearningRate(0.02)
                //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
                        .nIn(nChannels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation("identity")
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2,2)
                        .stride(2,2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        //Note that nIn need not be specified in later layers
                        .stride(1, 1)
                        .nOut(50)
                        .activation("identity")
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2,2)
                        .stride(2,2)
                        .build())
                .layer(4, new DenseLayer.Builder().activation("relu")
                        .nOut(500).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation("softmax")
                        .build())
                .backprop(true).pretrain(false)
                .cnnInputSize(28, 28, 1);
        // The builder needs the image dimensions along with the number of channels:
        // these are 28x28 images with a single channel.
        //new ConvolutionLayerSetup(builder,28,28,1);

        MultiLayerConfiguration conf = builder.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));  // prints the loss-function score after each iteration

As you can see, a neural network involves many hyperparameters: the learning rate, the regularization coefficient, the kernel sizes, the activation functions, and more all have to be set by hand, and different choices can change the results dramatically. In fact, I later found that most of my time went into data processing and tuning. Since my own ability to design networks is limited, I generally follow the papers of the experts and reimplement their designs. The LeNet implemented here is: convolution -> subsampling -> convolution -> subsampling -> fully connected, essentially the same structure as the original paper, with the kernel sizes also taken from the paper (see the link above for details). We also register a score listener, so that during training we can watch how the loss function converges after each weight update; its output appears further down.
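
To make the structure concrete, here is the shape arithmetic implied by the configuration above, using the usual no-padding convolution formula out = (in - kernel) / stride + 1. This is just a hand check of the dimensions, not Deeplearning4j output:

public class LenetShapes {
    // Output size along one dimension for a convolution/pooling with no padding
    static int outSize(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1;
    }

    public static void main(String[] args) {
        int s = 28;                 // input: 28x28x1
        s = outSize(s, 5, 1);       // conv 5x5, 20 filters
        System.out.println("conv1: " + s + "x" + s + "x20");  // 24x24x20
        s = outSize(s, 2, 2);       // max-pool 2x2, stride 2
        System.out.println("pool1: " + s + "x" + s + "x20");  // 12x12x20
        s = outSize(s, 5, 1);       // conv 5x5, 50 filters
        System.out.println("conv2: " + s + "x" + s + "x50");  // 8x8x50
        s = outSize(s, 2, 2);       // max-pool 2x2, stride 2
        System.out.println("pool2: " + s + "x" + s + "x50");  // 4x4x50
        System.out.println("flattened: " + (s * s * 50)
                + " -> dense 500 -> output 10");               // 800 inputs to the dense layer
    }
}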

Once the network structure is defined, we can train on the data we read earlier and evaluate the classification accuracy. First, the code:

for( int i = 0; i < nEpochs; ++i ) {
    model.fit(mnistTrain);
    log.info("*** Completed epoch " + i + " ***");

    log.info("Evaluate model....");
    Evaluation eval = new Evaluation(outputNum);
    while (mnistTest.hasNext()) {
        DataSet ds = mnistTest.next();
        INDArray output = model.output(ds.getFeatureMatrix(), false);  // inference only, no training
        eval.eval(ds.getLabels(), output);
    }
    log.info(eval.stats());
    mnistTest.reset();  // rewind the test iterator for the next epoch
}

    This part should be easy to follow. After each training epoch, we evaluate the model on the test set and print a report like the one below. The first part is the per-class breakdown, where you can see exactly how many images were classified correctly and how the rest were misclassified. After waiting patiently through the ten epochs, LeNet reaches about 99% classification accuracy on MNIST, as shown below:

Examples labeled as 0 classified by model as 0: 974 times
Examples labeled as 0 classified by model as 6: 2 times
Examples labeled as 0 classified by model as 7: 2 times
Examples labeled as 0 classified by model as 8: 1 times
Examples labeled as 0 classified by model as 9: 1 times
Examples labeled as 1 classified by model as 0: 1 times
Examples labeled as 1 classified by model as 1: 1128 times
Examples labeled as 1 classified by model as 2: 1 times
Examples labeled as 1 classified by model as 3: 2 times
Examples labeled as 1 classified by model as 5: 1 times
Examples labeled as 1 classified by model as 6: 2 times
Examples labeled as 2 classified by model as 2: 1026 times
Examples labeled as 2 classified by model as 4: 1 times
Examples labeled as 2 classified by model as 6: 1 times
Examples labeled as 2 classified by model as 7: 3 times
Examples labeled as 2 classified by model as 8: 1 times
Examples labeled as 3 classified by model as 0: 1 times
Examples labeled as 3 classified by model as 1: 1 times
Examples labeled as 3 classified by model as 2: 1 times
Examples labeled as 3 classified by model as 3: 998 times
Examples labeled as 3 classified by model as 5: 3 times
Examples labeled as 3 classified by model as 7: 1 times
Examples labeled as 3 classified by model as 8: 4 times
Examples labeled as 3 classified by model as 9: 1 times
Examples labeled as 4 classified by model as 2: 1 times
Examples labeled as 4 classified by model as 4: 973 times
Examples labeled as 4 classified by model as 6: 2 times
Examples labeled as 4 classified by model as 7: 1 times
Examples labeled as 4 classified by model as 9: 5 times
Examples labeled as 5 classified by model as 0: 2 times
Examples labeled as 5 classified by model as 3: 4 times
Examples labeled as 5 classified by model as 5: 882 times
Examples labeled as 5 classified by model as 6: 1 times
Examples labeled as 5 classified by model as 7: 1 times
Examples labeled as 5 classified by model as 8: 2 times
Examples labeled as 6 classified by model as 0: 4 times
Examples labeled as 6 classified by model as 1: 2 times
Examples labeled as 6 classified by model as 4: 1 times
Examples labeled as 6 classified by model as 5: 4 times
Examples labeled as 6 classified by model as 6: 945 times
Examples labeled as 6 classified by model as 8: 2 times
Examples labeled as 7 classified by model as 1: 5 times
Examples labeled as 7 classified by model as 2: 3 times
Examples labeled as 7 classified by model as 3: 1 times
Examples labeled as 7 classified by model as 7: 1016 times
Examples labeled as 7 classified by model as 8: 1 times
Examples labeled as 7 classified by model as 9: 2 times
Examples labeled as 8 classified by model as 0: 1 times
Examples labeled as 8 classified by model as 3: 1 times
Examples labeled as 8 classified by model as 5: 2 times
Examples labeled as 8 classified by model as 7: 2 times
Examples labeled as 8 classified by model as 8: 966 times
Examples labeled as 8 classified by model as 9: 2 times
Examples labeled as 9 classified by model as 3: 1 times
Examples labeled as 9 classified by model as 4: 2 times
Examples labeled as 9 classified by model as 5: 4 times
Examples labeled as 9 classified by model as 6: 1 times
Examples labeled as 9 classified by model as 7: 5 times
Examples labeled as 9 classified by model as 8: 3 times
Examples labeled as 9 classified by model as 9: 993 times


==========================Scores========================================
 Accuracy:        0.9901
 Precision:       0.99
 Recall:          0.99
 F1 Score:        0.99
========================================================================
[main] INFO cv.LenetMnistExample - ****************Example finished********************

    Since I couldn't upload the images, I've pasted the results directly. They show the final accuracy, as well as which images were classified correctly and which were not. We could certainly push further by training for more epochs and tuning the hyperparameters, but a result like this is in fact already usable in production.
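
    If you do want to move a model like this toward production, the trained network has to be saved and reloaded inside the serving code. Here is a minimal sketch using Deeplearning4j's ModelSerializer; the file name is just an example:

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import java.io.File;

// ... after training finishes:
File modelFile = new File("lenet-mnist.zip");  // example path, pick your own

// Save the configuration and weights; the boolean controls whether the
// updater state (momentum etc.) is kept so training could resume later.
ModelSerializer.writeModel(model, modelFile, true);

// Later, e.g. inside a Java web service, restore it and predict with output().
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(modelFile);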

    To wrap up: many people, myself included, don't really understand deep learning. I remember reading the deep learning posts on CSDN and thinking I could never reach that level. But we were all missing one point: however complex deep learning is, it is still an algorithmic model, still a kind of machine learning. It is far more complex than a perceptron or logistic regression (in fact, logistic regression can be viewed as a single neuron in a neural network, playing the role of the activation function; there are many similar activations, such as tanh and relu), but in the end it serves the same purposes: regression, classification, data compression, and so on. So taking the first step matters. Of course, we can't start with the most complex, most fashionable models, which is why we begin with the MNIST example and get a feel for things. Later I'll write up some cases of Deeplearning4j on Spark, again starting from MNIST; sharing them is also a chance to review for myself.
