1 神經網絡結構
1.1 保存*.pb模型
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
import os
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from mpl_toolkits.mplot3d import Axes3D
# Font capable of rendering the Chinese plot labels (Ubuntu system font path).
font = FontProperties(fname='/usr/share/fonts/truetype/arphic/ukai.ttc')

# Output locations for the exported model and the TensorBoard logs.
MODEL_SAVE_PATH = "./models"
MODEL_NAME_meta = "nn_model.ckpt"
MODEL_NAME_pb = "nn_model.pb"
LOG_DIR = "./logs/NNmergelog"

# Simulated data: 250 evenly spaced points in [-1, 1] as a (250, 1) float32
# column, with targets y = x^2 - 0.5*x plus Gaussian noise (std 0.05).
x_data = np.linspace(-1, 1, 250, dtype=np.float32).reshape(-1, 1)
noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32)
y_data = x_data ** 2 - 0.5 * x_data + noise

# Network shape: 1 input -> 10 hidden units -> 1 output.
input_size_1, output_size_1 = 1, 10
input_size_2, output_size_2 = output_size_1, 1
def saved_model_pb():
    """Train the 1-10-1 regression network and export it as a frozen *.pb file.

    The graph is built in a fresh ``tf.Graph``. Repeated calls in one process
    therefore do not accumulate duplicate nodes or summaries in the default
    graph, which would break the hard-coded node names used for export.

    Training runs 301 full-batch gradient-descent steps (learning rate 0.1) on
    the module-level ``x_data``/``y_data``. Every step's merged summary is
    written to ``LOG_DIR``, and every 50th step prints the weights and the loss.

    Only after training are the variables folded into constants. The frozen
    graph is then serialized to ``os.path.join(MODEL_SAVE_PATH, MODEL_NAME_pb)``,
    and the directory is created if needed.
    """
    with tf.Graph().as_default() as graph:
        # Input layer.
        with tf.name_scope("Input"):
            xs = tf.placeholder(tf.float32, [None, 1], name='x')
            ys = tf.placeholder(tf.float32, [None, 1], name='y')
        # Hidden layer (ReLU).
        with tf.name_scope("Layer1"):
            weights_1 = tf.Variable(tf.random_normal([input_size_1, output_size_1]), name='weights_1')
            biases_1 = tf.Variable(tf.zeros([1, output_size_1]), name='biases_1')
            layer_1 = tf.nn.relu(tf.matmul(xs, weights_1) + biases_1)
            tf.summary.histogram('weights_1', weights_1)
            tf.summary.histogram('biases_1', biases_1)
            tf.summary.histogram('layer_1', layer_1)
        # Output layer (linear). "Output/predictions" is the node exported below.
        with tf.name_scope("Output"):
            weights_2 = tf.Variable(tf.random_normal([input_size_2, output_size_2]), name='weights_2')
            biases_2 = tf.Variable(tf.zeros([1, output_size_2]), name='biases_2')
            outputs_2 = tf.matmul(layer_1, weights_2)
            prediction = tf.add(outputs_2, biases_2, name="predictions")
            tf.summary.histogram('weights_2', weights_2)
            tf.summary.histogram('biases_2', biases_2)
            tf.summary.histogram('prediction', prediction)
        # Loss: squared error summed per sample (one output), averaged over the batch.
        with tf.name_scope("Loss"):
            loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), axis=[1]))
            tf.summary.scalar('loss', loss)
            tf.summary.histogram('loss', loss)
        # Optimizer.
        with tf.name_scope("Train_Step"):
            train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
        # Merge every summary defined above into a single op.
        merged = tf.summary.merge_all()
        init_op = tf.global_variables_initializer()

        with tf.Session(graph=graph) as sess:
            summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
            try:
                sess.run(init_op)
                for i in range(301):
                    # Fetch the summary and loss in the same run as the update.
                    summary, train_step_value, loss_value = sess.run(
                        [merged, train_step, loss],
                        feed_dict={xs: x_data, ys: y_data})
                    summary_writer.add_summary(summary, i)
                    if i % 50 == 0:
                        # Report training progress every 50 steps.
                        w1, w2 = sess.run([weights_1, weights_2])
                        print("Weights_1 :{}".format(w1))
                        print("weights_2 :{}".format(w2))
                        print("Loss :{}".format(loss_value))
                        print(prediction)
                        print(loss)
                        print(train_step_value)
            finally:
                summary_writer.close()
            # Freeze once, after training, so the exported graph holds the final
            # weights. Previously this ran before each step's update, so the
            # saved model always missed the last update.
            constant_graph = graph_util.convert_variables_to_constants(
                sess, sess.graph_def, ["Input/x", "Input/y", "Output/predictions"])

    # The save directory may not exist on a fresh checkout.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    with tf.gfile.FastGFile(os.path.join(MODEL_SAVE_PATH, MODEL_NAME_pb), mode="wb") as f:
        f.write(constant_graph.SerializeToString())
1.2 載入*.pb模型
def load_pb_model():
    """Load the frozen *.pb model, predict on fresh data and plot the result.

    Reads ``os.path.join(MODEL_SAVE_PATH, MODEL_NAME_pb)``, the same file
    ``saved_model_pb`` writes, into a fresh ``tf.Graph``. It feeds a freshly
    generated noisy copy of the training curve to ``Input/x:0`` and fetches
    ``Output/predictions:0``. The predictions are drawn as a line over the
    noisy samples, saved to ``./images/pb_load.png`` and shown.
    """
    # Evaluation data: same curve as the training data, with new noise.
    x_data = np.linspace(-1, 1, 250, dtype=np.float32)[:, np.newaxis]
    noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32)
    y_data = np.square(x_data) - 0.5*x_data + noise

    model_path = os.path.join(MODEL_SAVE_PATH, MODEL_NAME_pb)
    # A private graph keeps repeated loads from clashing with existing nodes
    # in the default graph.
    with tf.Graph().as_default() as g:
        with gfile.FastGFile(model_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # name='' imports without an extra prefix, so the tensors keep the
        # name-scoped names they had at export time ("Input/...", "Output/...").
        tf.import_graph_def(graph_def, name='')
        pre_tensor = g.get_tensor_by_name("Output/predictions:0")
        x = g.get_tensor_by_name("Input/x:0")
        with tf.Session(graph=g) as sess:
            pre = sess.run(pre_tensor, feed_dict={x: x_data})

    plt.figure(figsize=(6, 6))
    plt.plot(x_data, pre, label="預測結果")
    plt.grid()
    plt.xlabel("x軸", fontproperties=font)
    plt.ylabel("y軸", fontproperties=font)
    plt.scatter(x_data, y_data, s=10, c="r", marker="*", label="實際值")
    plt.legend(prop=font)
    # savefig does not create missing directories.
    os.makedirs("./images", exist_ok=True)
    plt.savefig("./images/pb_load.png", format="png")
    plt.show()
2 結果
2.1 訓練結果
權重與損失值結果(每50步打印一次):
Weights_1 :[[-1.7107134 0.5941573 0.37450954 0.53004044 0.3793792 0.9144222
-1.5825071 -0.6608934 -0.96931577 0.5307749 ]]
weights_2 :[[ 0.31370506]
[-1.4543793 ]
[-1.9223864 ]
[ 0.14437917]
[-1.1137098 ]
[ 0.05373428]
[ 0.6884544 ]
[-0.01735083]
[ 0.03221066]
[ 1.2156694 ]]
Loss :0.004902012180536985
Tensor("Output/add:0", shape=(?, 1), dtype=float32)
Tensor("Loss/Mean:0", shape=(), dtype=float32)
None
2.2 載入模型驗證結果
3 Java調用*.pb模型
3.1 pom.xml
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.5.0</version>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.6</version>
</dependency>
3.2 控制層
package com.sb.controller;
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.util.ResourceUtils;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.File;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@CrossOrigin(origins="*", maxAge=3600)
@RestController
@RequestMapping("/api/ai")
public class AIController{
    static Logger logger = LoggerFactory.getLogger(AIController.class);

    /** Classpath location of the frozen model exported by the Python script. */
    private static final String MODEL_RESOURCE = "model/nn_model.pb";

    /**
     * Feeds a single scalar through the frozen model and returns its prediction.
     *
     * <p>Feeds {@code Input/x:0} with a 1x1 matrix and fetches
     * {@code Output/predictions:0}. This is the Java equivalent of the Python
     * call {@code sess.run(pre, feed_dict={x: x_data})}.
     *
     * @param datas request body; its {@code "input"} entry is parsed as a float
     * @return the predicted value, or {@code input} itself if inference fails
     *         (the failure is logged)
     * @throws IllegalArgumentException if {@code "input"} is missing
     * @throws NumberFormatException if {@code "input"} is not a number
     * @throws IOException if the model resource is missing or unreadable
     */
    @RequestMapping(value="pre", method=RequestMethod.POST)
    public Float predictionTest(@RequestBody Map<String, Object> datas) throws Exception{
        Object rawInput = datas.get("input");
        if (rawInput == null) {
            throw new IllegalArgumentException("request body must contain \"input\"");
        }
        float input = Float.parseFloat(rawInput.toString());
        // The model expects a [batch, 1] matrix; this batch holds a single row.
        float[][] x = {{input}};
        // Log the scalar itself. Passing the float[][] makes SLF4J treat it as
        // the varargs array and log an unreadable "[F@..." reference.
        logger.info("input:{}", input);
        byte[] graphBytes = loadModelBytes();
        try(Graph graph = new Graph()){
            graph.importGraphDef(graphBytes);
            // Tensors hold native memory, so both the fed and the fetched
            // tensor are released by try-with-resources.
            try(Session session = new Session(graph);
                Tensor inputTensor = Tensor.create(x);
                Tensor prediction = session.runner()
                        .feed("Input/x:0", inputTensor)
                        .fetch("Output/predictions:0").run().get(0)){
                // The output has shape [1, 1], so floatValue() (scalars only) fails
                // with "Tensor is not a scalar". Copy it into a 1x1 array instead.
                float[][] preOutput = (float[][])prediction.copyTo(new float[1][1]);
                return preOutput[0][0];
            }
        }catch(Exception e){
            logger.error("prediction failed for input {}, returning the input unchanged", input, e);
        }
        return input;
    }

    /**
     * Reads the model bytes from the classpath.
     *
     * <p>Uses a stream rather than a {@code File}, so it also works when the
     * application runs from a packaged jar.
     */
    private static byte[] loadModelBytes() throws IOException {
        try (java.io.InputStream in = AIController.class.getClassLoader().getResourceAsStream(MODEL_RESOURCE)) {
            if (in == null) {
                throw new java.io.FileNotFoundException("classpath:" + MODEL_RESOURCE);
            }
            return IOUtils.toByteArray(in);
        }
    }
}
4 總結
- *.pb模型文件基於Protocol Buffers序列化,格式開放且與語言無關,可脫離訓練代碼獨立運行,並能被多種語言(如Python、Java)解析調用。
- *.pb模型文件中變量是固定的(const)即模型中的變量值固定存儲。
- *.pb模型文件使用過程中不會重新“學習”,即模型參數不變,保證了模型的穩定性。
- *.pb模型文件只保留從輸出節點可達的推理子圖,並去掉了優化器等訓練相關節點,因此文件相對精簡,適合部署(包括移動端)。
基礎閱讀:
[1](一)Tensorflow搭建神經網絡
[2](二)Tensorflow神經網絡保存模型(持久化)
[3](三)Tensorflow神經網絡之模型載入及遷移學習
[參考文獻]
[1]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/gfile/FastGFile
[2]https://blog.csdn.net/fu6543210/article/details/80343345
[3]https://blog.csdn.net/wshzd/article/details/88840792