Springboot2模塊系列:tensorflow(載入pb模型)

1 神經網絡結構

1.1 保存*.pb模型

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
import os

import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from mpl_toolkits.mplot3d import Axes3D
# Font used for the Chinese axis labels / legend in the plots below.
# This is an Ubuntu font path; adjust it on other systems.
font = FontProperties(fname="/usr/share/fonts/truetype/arphic/ukai.ttc")

# Output locations for the trained model files and TensorBoard logs.
MODEL_SAVE_PATH = "./models"
MODEL_NAME_meta = "nn_model.ckpt"
MODEL_NAME_pb = "nn_model.pb"
LOG_DIR = "./logs/NNmergelog"

# Synthetic training data: 250 evenly spaced points of y = x^2 - 0.5x plus
# Gaussian noise (std 0.05), stored as (250, 1) column vectors so they can be
# fed directly into the [None, 1] placeholders.
x_data = np.linspace(-1, 1, 250, dtype=np.float32).reshape(-1, 1)
noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32)
y_data = np.square(x_data) - 0.5 * x_data + noise

# Layer sizes of the 1 -> 10 -> 1 network.
input_size_1, output_size_1 = 1, 10
input_size_2, output_size_2 = output_size_1, 1
def saved_model_pb():
	"""Build a 1-10-1 regression network, train it, and freeze it to a .pb file.

	The network is trained with plain gradient descent (learning rate 0.1) on
	the module-level ``x_data`` / ``y_data`` for 301 steps. TensorBoard
	summaries are written to ``LOG_DIR`` every step, and progress is printed
	every 50 steps.

	After training finishes, the variables are folded into constants once.
	The resulting GraphDef is written to
	``os.path.join(MODEL_SAVE_PATH, MODEL_NAME_pb)``. The node names
	``Input/x`` and ``Output/predictions`` are kept, so loaders can feed and
	fetch them by name.
	"""
	# Input layer: placeholders named so they can be looked up after freezing.
	with tf.name_scope("Input"):
		xs = tf.placeholder(tf.float32, [None, 1], name='x')
		ys = tf.placeholder(tf.float32, [None, 1], name='y')

	# Hidden layer: input_size_1 -> output_size_1 with ReLU.
	with tf.name_scope("Layer1"):
		weights_1 = tf.Variable(tf.random_normal([input_size_1, output_size_1]), name='weights_1')
		biases_1 = tf.Variable(tf.zeros([1, output_size_1]), name='biases_1')
		layer_1 = tf.nn.relu(tf.matmul(xs, weights_1) + biases_1)
		tf.summary.histogram('weights_1', weights_1)
		tf.summary.histogram('biases_1', biases_1)
		tf.summary.histogram('layer_1', layer_1)

	# Output layer: input_size_2 -> output_size_2, linear. The final op is
	# explicitly named "predictions" so it can be fetched by name later.
	with tf.name_scope("Output"):
		weights_2 = tf.Variable(tf.random_normal([input_size_2, output_size_2]), name='weights_2')
		biases_2 = tf.Variable(tf.zeros([1, output_size_2]), name='biases_2')
		outputs_2 = tf.matmul(layer_1, weights_2)
		prediction = tf.add(outputs_2, biases_2, name="predictions")
		tf.summary.histogram('weights_2', weights_2)
		tf.summary.histogram('biases_2', biases_2)
		tf.summary.histogram('prediction', prediction)

	# Loss: per-sample sum of squared errors, averaged over the batch.
	with tf.name_scope("Loss"):
		loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))
		tf.summary.scalar('loss', loss)
		tf.summary.histogram('loss', loss)

	with tf.name_scope("Train_Step"):
		train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

	merged = tf.summary.merge_all()

	# Writing the .pb file below needs the target directory to exist.
	os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

	with tf.Session() as sess:
		summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
		sess.run(tf.global_variables_initializer())
		for i in range(301):
			summary, train_step_value, loss_value, _ = sess.run(
				[merged, train_step, loss, prediction],
				feed_dict={xs: x_data, ys: y_data})
			summary_writer.add_summary(summary, i)
			if i % 50 == 0:
				w1, w2 = sess.run([weights_1, weights_2])
				print("Weights_1 :{}".format(w1))
				print("weights_2 :{}".format(w2))
				print("Loss :{}".format(loss_value))
				# These print the symbolic Tensor handles (and None for the
				# train op's result), not numeric values.
				print(prediction)
				print(loss)
				print(train_step_value)
		summary_writer.close()

		# Freeze exactly once, after training, so the saved constants are the
		# final trained weights (freezing inside the loop captured pre-step
		# values and rewrote the file on every iteration).
		constant_graph = graph_util.convert_variables_to_constants(
			sess, sess.graph_def, ["Input/x", "Input/y", "Output/predictions"])
		with tf.gfile.GFile(os.path.join(MODEL_SAVE_PATH, MODEL_NAME_pb), mode="wb") as f:
			f.write(constant_graph.SerializeToString())

1.2 載入*.pb模型

def load_pb_model():
	"""Load the frozen model, predict on fresh data, and plot the result.

	A new noisy sample of the training curve (y = x^2 - 0.5x + noise) is
	generated. The x values are fed through ``Input/x:0``, and the
	``Output/predictions:0`` values are plotted together with the noisy
	targets. The figure is saved to ``./images/pb_load.png`` and then shown.
	"""
	# Fresh evaluation data drawn from the same distribution as the training set.
	x_data = np.linspace(-1, 1, 250, dtype=np.float32)[:, np.newaxis]
	noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32)
	y_data = np.square(x_data) - 0.5*x_data + noise

	# Import into a dedicated graph, so that repeated calls do not keep adding
	# nodes to the process-wide default graph.
	graph = tf.Graph()
	with graph.as_default():
		graph_def = tf.GraphDef()
		with gfile.GFile(os.path.join(MODEL_SAVE_PATH, MODEL_NAME_pb), "rb") as f:
			graph_def.ParseFromString(f.read())
		# name='' keeps the original node names (no "import/" prefix).
		tf.import_graph_def(graph_def, name='')

	# The nodes were created under name scopes, so the scope prefix is part of
	# each tensor's name.
	x = graph.get_tensor_by_name("Input/x:0")
	pre_tensor = graph.get_tensor_by_name("Output/predictions:0")

	with tf.Session(graph=graph) as sess:
		pre = sess.run(pre_tensor, feed_dict={x: x_data})

	plt.figure(figsize=(6, 6))
	plt.plot(x_data, pre, label="預測結果")
	plt.grid()
	plt.xlabel("x軸", fontproperties=font)
	plt.ylabel("y軸", fontproperties=font)
	plt.scatter(x_data, y_data, s=10, c="r", marker="*", label="實際值")
	plt.legend(prop=font)

	# savefig does not create missing directories.
	os.makedirs("./images", exist_ok=True)
	plt.savefig("./images/pb_load.png", format="png")
	plt.show()

2 結果

2.1 訓練結果

訓練過程中打印的權重與損失值(程序未打印偏置項):

Weights_1 :[[-1.7107134   0.5941573   0.37450954  0.53004044  0.3793792   0.9144222
  -1.5825071  -0.6608934  -0.96931577  0.5307749 ]]
weights_2 :[[ 0.31370506]
 [-1.4543793 ]
 [-1.9223864 ]
 [ 0.14437917]
 [-1.1137098 ]
 [ 0.05373428]
 [ 0.6884544 ]
 [-0.01735083]
 [ 0.03221066]
 [ 1.2156694 ]]
Loss :0.004902012180536985
Tensor("Output/add:0", shape=(?, 1), dtype=float32)
Tensor("Loss/Mean:0", shape=(), dtype=float32)
None

2.2 載入模型驗證結果

(圖片缺失:load_pb_model 保存的 ./images/pb_load.png,即載入模型後的預測曲線與實際值散點圖)

圖2.1 預測值與實際值

3 java調用pb模型

3.1 pom.xml

<dependency>
      <groupId>org.tensorflow</groupId>
      <artifactId>tensorflow</artifactId>
      <version>1.5.0</version>
    </dependency>
    <dependency>
      <groupId>commons-io</groupId>
      <artifactId>commons-io</artifactId>
      <version>2.6</version>
    </dependency>

3.2 控制層

package com.sb.controller;

import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.util.ResourceUtils;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.File;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@CrossOrigin(origins="*", maxAge=3600)
@RestController
@RequestMapping("/api/ai")
public class AIController{
    static Logger logger = LoggerFactory.getLogger(AIController.class);

    /** Classpath location of the frozen model, resolved through the class loader. */
    private static final String MODEL_RESOURCE = "model/nn_model.pb";

    /**
     * Runs the frozen TensorFlow model on a single scalar input.
     *
     * The request body must contain the key "input". Its value is parsed with
     * Float.parseFloat on its string form, e.g. {"input": 0.5}.
     *
     * @param datas request body containing the key "input"
     * @return the model prediction; if loading or evaluating the graph fails,
     *         the failure is logged and the input value is returned unchanged
     * @throws Exception if "input" is missing or not a number, or if the model
     *         resource is not on the classpath
     */
    @RequestMapping(value="pre", method=RequestMethod.POST)
    public Float predictionTest(@RequestBody Map<String, Object> datas) throws Exception{
        float input = Float.parseFloat(datas.get("input").toString());
        // The model's input placeholder has shape [None, 1], so feed a 1x1 matrix.
        float[][] x = {{input}};
        logger.info("input:{}", input);

        // Read through the class loader rather than ResourceUtils.getFile, which
        // only works when the resource is a real file (not inside a packaged jar).
        java.io.InputStream modelStream =
                AIController.class.getClassLoader().getResourceAsStream(MODEL_RESOURCE);
        if (modelStream == null) {
            throw new java.io.FileNotFoundException("classpath:" + MODEL_RESOURCE);
        }

        // Every resource below (stream, Graph, Session, both Tensors) holds native
        // or OS resources and is closed by try-with-resources.
        try (java.io.InputStream in = modelStream; Graph graph = new Graph()) {
            graph.importGraphDef(IOUtils.toByteArray(in));
            try (Session session = new Session(graph);
                 Tensor<?> inputTensor = Tensor.create(x);
                 Tensor<?> prediction = session.runner()
                         .feed("Input/x:0", inputTensor)
                         .fetch("Output/predictions:0")
                         .run()
                         .get(0)) {
                // The output has shape [1, 1]. floatValue() only accepts scalars
                // (it throws "Tensor is not a scalar"), so copy into a 1x1 array.
                float[][] preOutput = prediction.copyTo(new float[1][1]);
                return preOutput[0][0];
            }
        } catch (Exception e) {
            logger.error("Model prediction failed; returning the input unchanged", e);
        }
        return input;
    }
}

4 總結

  • *.pb模型文件採用Protocol Buffers序列化格式,與語言無關,可脫離訓練代碼獨立運行,任何能解析Protocol Buffers並加載TensorFlow計算圖的語言(如本文的Java)均可使用。
  • *.pb模型文件中的變量已轉換為常量(Const),即訓練得到的變量值直接固化存儲在模型文件中。
  • *.pb模型文件使用過程中不會重新“學習”,即模型參數不變,保證了模型的穩定性。
  • *.pb模型文件實現了模型的瘦身:凍結時只保留計算指定輸出節點所需的部分,訓練與優化相關節點被剔除,權重與計算圖合併為單一文件,便於部署到服務端或移動端。

基礎閱讀:
[1](一)Tensorflow搭建神經網絡
[2](二)Tensorflow神經網絡保存模型(持久化)
[3](三)Tensorflow神經網絡之模型載入及遷移學習


[參考文獻]
[1]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/gfile/FastGFile
[2]https://blog.csdn.net/fu6543210/article/details/80343345
[3]https://blog.csdn.net/wshzd/article/details/88840792


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章