整合DL4J訓練模型與Web工程 頂 原 薦

一、前言

    上一篇博客《有趣的卷積神經網絡》介紹如何基於deeplearning4j對手寫數字識別進行訓練,對於整個訓練集只訓練了一次,正確率是0.9897,隨着迭代次數的增加,網絡模型將更加逼近訓練集,下面是對訓練集迭代十次的評估結果,總之迭代次數的增加會更加逼近模型(注:增加迭代次數有時也會發生過擬合,有時候也並非很奏效,具體情況具體分析)。

 Accuracy:        0.9919
 Precision:       0.9919
 Recall:          0.9918
 F1 Score:        0.9918

二、導讀

    1、web環境搭建

    2、基於canvas構建前端畫圖界面

    3、整合dl4j訓練模型

三、web環境搭建

    1、eclipse  new一個Maven project ,填好maven座標,packaging選war

<groupId>org.dl4j</groupId>
<artifactId>digitalrecognition</artifactId>
<version>0.0.1-SNAPSHOT</version>
<packaging>war</packaging>

    2、配置Jar包依賴,由於servlet-api一般由web容器提供,所以scope爲provided,這樣不會被打入war包裏。

<dependencies>
		<dependency>
			<groupId>org.springframework</groupId>
			<artifactId>spring-webmvc</artifactId>
			<version>4.3.4.RELEASE</version>
		</dependency>
		<dependency>
			<groupId>javax.servlet</groupId>
			<artifactId>servlet-api</artifactId>
			<version>2.5</version>
			<scope>provided</scope>
		</dependency>
		<dependency>
			<groupId>com.fasterxml.jackson.core</groupId>
			<artifactId>jackson-core</artifactId>
			<version>2.5.3</version>
		</dependency>

		<dependency>
			<groupId>com.fasterxml.jackson.core</groupId>
			<artifactId>jackson-annotations</artifactId>
			<version>2.5.3</version>
		</dependency>

		<dependency>
			<groupId>com.fasterxml.jackson.core</groupId>
			<artifactId>jackson-databind</artifactId>
			<version>2.5.3</version>
		</dependency>
		<dependency>
			<groupId>commons-fileupload</groupId>
			<artifactId>commons-fileupload</artifactId>
			<version>1.3.1</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-core</artifactId>
			<version>0.9.1</version>
		</dependency>
		<dependency>
			<groupId>org.nd4j</groupId>
			<artifactId>nd4j-native-platform</artifactId>
			<version>0.9.1</version>
		</dependency>
	</dependencies>

    3、爲了開發方便,不用把web工程部署到外置web容器,所以在開發時用mavan tomcat插件是比較方便的。運行時mvn tomcat7:run即可

<build>
		<plugins>
			<plugin>
				<groupId>org.apache.tomcat.maven</groupId>
				<artifactId>tomcat7-maven-plugin</artifactId>
				<version>2.2</version>
				<configuration>
					<uriEncoding>UTF-8</uriEncoding>
					<path>/</path>
					<port>8080</port>
					<protocol>org.apache.coyote.http11.Http11NioProtocol</protocol>
					<maxThreads>1000</maxThreads>
					<minSpareThreads>100</minSpareThreads>
				</configuration>
			</plugin>
		</plugins>
	</build>

    4、web常規配置web.xml,filter、servlet、listener這裏就略去了。

四、前端canvas畫圖實現

    1、html元素、css

<style type="text/css">
body {
	padding: 0;
	margin: 0;
	background: white;
}

#canvas {
	margin: 100px 0 0 300px;
}

#canvas>span {
	color: white;
	font-size: 14px;
}

#result {
	margin: 0px 0 0 300px;
}
</style>
<html>
<head>
<title>數字識別</title>
</head>
<body>
	<canvas id="canvas" width="280" height="280"></canvas>
	<button onclick="predict()">預測</button>
	<div id="result">
		識別結果:<font size="18" id="digit"></font>
	</div>
</body>
</html>

    2、js代碼實現在canvas畫布連線操作,並將圖片轉化爲base64格式,ajax發送給後端,這裏畫布的大小是280px,所以圖片到了後端,需要縮小至十分之一。

<script src="/js/jquery-3.2.1.min.js"></script>
<script type="text/javascript">
	/*獲取繪製環境*/
	var canvas = $('#canvas')[0].getContext('2d');
	canvas.strokeStyle = "white";//線條的顏色
	canvas.lineWidth = 10;//線條粗細
	canvas.fillStyle = 'black'
	canvas.fillRect(0, 0, 280, 280);
	$('#canvas').on('mousedown', function() {
		/*開始繪製*/
		canvas.beginPath();
		/*設置動畫繪製起點座標*/
		canvas.moveTo(event.pageX - 300, event.pageY - 100);
		$('#canvas').on('mousemove', function() {
			/*設置下一個點座標*/
			canvas.lineTo(event.pageX - 300, event.pageY - 100);
			/*畫線*/
			canvas.stroke();
		});
	}).on('mouseup', function() {
		$('#canvas').off('mousemove');
	});
	function predict() {
		var img = $('#canvas')[0].toDataURL("image/png");
		$.ajax({
			url : "/digitalRecognition/predict",
			type : "post",
			data : {
				"img" : img.substring(img.indexOf(",") + 1)
			},
			success : function(response) {
				$("#digit").html(response);
			},
			error : function() {
			}
		});
	}
</script>

    整體呈現的界面如下,可以畫圖。

五、後端java代碼

@RequestMapping("/digitalRecognition")
@Controller
public class DigitalRecognitionController implements InitializingBean {
	private MultiLayerNetwork net;

	@ResponseBody
	@RequestMapping("/predict")
	public int predict(@RequestParam(value = "img") String img) throws Exception {
		String imagePath= generateImage(img);//將base64圖片轉化爲png圖片
		imagePath= zoomImage(imagePath);//將圖片縮小至28*28
		DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
		ImageRecordReader testRR = new ImageRecordReader(28, 28, 1);
		File testData = new File(imagePath);
		FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS);
		testRR.initialize(testSplit);
		DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, 1);
		testIter.setPreProcessor(scaler);
		INDArray array = testIter.next().getFeatureMatrix();
		return net.predict(array)[0];
	}

	private String generateImage(String img) {
		BASE64Decoder decoder = new BASE64Decoder();
		String filePath = WebConstant.WEB_ROOT + "upload/"+UUID.randomUUID().toString()+".png";
		try {
			byte[] b = decoder.decodeBuffer(img);
			for (int i = 0; i < b.length; ++i) {
				if (b[i] < 0) {
					b[i] += 256;
				}
			}
			OutputStream out = new FileOutputStream(filePath);
			out.write(b);
			out.flush();
			out.close();
		} catch (Exception e) {
			e.printStackTrace();
		}
		return filePath;
	}
	
	private String zoomImage(String filePath){
		String imagePath=WebConstant.WEB_ROOT + "upload/"+UUID.randomUUID().toString()+".png";
		try {
			BufferedImage bufferedImage = ImageIO.read(new File(filePath));
			Image image = bufferedImage.getScaledInstance(28, 28, Image.SCALE_SMOOTH);
			BufferedImage tag = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
			Graphics g = tag.getGraphics();
			g.drawImage(image, 0, 0, null); // 繪製處理後的圖
			g.dispose();
			ImageIO.write(tag, "png",new File(imagePath));
		} catch (Exception e) {
			e.printStackTrace();
		}
		return imagePath;
	}
	

	@Override
	public void afterPropertiesSet() throws Exception {
		net = ModelSerializer.restoreMultiLayerNetwork(new File(WebConstant.WEB_ROOT + "model/minist-model.zip"));
	}

}

    代碼說明:

    1、InitializingBean是spring bean生命週期中的一個環節,spring構建bean的過程中會執行afterPropertiesSet方法,這裏用這個方法來加載已經定型的網絡。

      2、generateImage是用來將前端傳過來的base64串轉化爲png格式。

      3、zoomImage方法將前端的280*280縮小至28*28和訓練數據一致,並存到webroot的upload目錄下。

     4、predict進行預測,將轉化好的28*28的圖片讀取出來,張量化,把像素點的值壓縮至0到1,預測,最後結果是一個數組,由於只有一張圖片,取數組的第一個元素即可。

六、測試,mvn tomcat7:run,瀏覽器訪問http://localhost:8080即可玩手寫數字識別了

    

           

    測試結果馬馬虎虎,大體上實現了基本功能。

    git地址:https://gitee.com/lxkm/dl4j-demo/tree/master/digitalrecognition

    快樂源於分享。

 

 

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章