自己在將tensorflow模型移動端部署的時候(使用 tensorflow lite),踩了很多坑,查了很多資料,現在做個記錄,所有參考資料在文章最後 參考 處列出。
tensorflow lite是TensorFlow Lite 是 Google I/O 2017 大會上的其中一個重要宣佈,有了TensorFlow Lite,應用開發者可以在移動設備上部署人工智能。
tensorflow lite 【github】
基本思路:
- 在pc端進行 Tensorflow 模型訓練,保存訓練模型
- 使用 工具將該模型轉換爲 Tensorflow lite 模型
- 在Android上使用
tensorflow模型持久化
在tensorflow中進行模型訓練,得到適合自己項目的模型。Tensorflow 模型訓練好之後會生成三個文件:
- model.ckpt.meta :保存Tensorflow計算圖結構,可以理解爲神經網絡的網絡結構
- model.ckpt :保存Tensorflow程序中每一個變量的取值,變量是模型中可訓練的部分
- checkpoint :保存一個目錄下所有模型文件列表
# 使用tf.train.write_graph導出GraphDef文件
tf.train.write_graph(sess.graph_def, "./", "mz_graph.pb", as_text=False)
# 使用tf.train.save導出checkpoint文件
saver.save(sess, model_path)
生成的模型文件如下圖所示:
bazel編譯需要的工具
Tensoflow使用的編譯工具是 bazel,谷歌開源的自動化構建工具。【bazel傳送門】
安裝bazel,用來編譯 tensorflow 轉 tflite 時用到的幾個工具,freeze、toco、summarize_graph(具體作用下面說),這些工具都在 tensorflow(從github上clone) 中,按下面命令進行編譯(在 tensorflow目錄下進行):
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/contrib/lite/toto:toto
Bazel build tensorflow/tools/graph_transforms:summarize_graph (查看模型結構,找出輸入輸出)
模型轉換
將訓練好的tf模型,進行freeze、toco操作,freeze主要是將 tensorflow模型持久化 中生成的文件進行合併,得到一個變量值和運算圖模型相結合的文件,是將變量值固定在圖中的操作。如上圖,這步生成 mz_freezegraph.pb .
summarize_graph
該命令查看整個Tensorflow模型概況,使用命令如下,運行之後,得到自己整個網絡結構,從中可以找到自己模型的輸入輸出,如下圖(模型比較亂。。。)
# “--in_graph=” 後面是模型存儲的位置
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=../mz_graph.pb
freeze_graph
該命令是 Tensorflow模型固化,將Tensorflow模型和計算圖上變量的值合二爲一,方便直接轉換 Tensorflow lite 模型。
bazel-bin/tensorflow/python/tools/freeze_graph\
--input_graph=/tmp/mobilenet_v1_224.pb \
--input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \
--input_binary=true \
--output_graph=/tmp/frozen_mobilenet_v1_224.pb \
--output_node_names=MobileNet/Predictions/Reshape_1
- input_graph :Tensorflow 模型結構文件
- input_checkpoint :Tensorflow 模型 ckpt 文件
- output_graph :輸出的freeze文件
- output_node_names :模型輸出節點名字,使用 summarize_graph 查看 ,可以在 Tensorflow 網絡訓練時進行命名
toco
固化模型到 tflite 模型轉化
toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/tmp/mobilenet_v1_1.0_224.tflite \
--inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays=input \
--output_arrays=MobilenetV1/Predictions/Reshape_1 \
--input_shapes=1,224,224,3
- input_file : freeze 之後的 Tensorflow 模型文件
- output_file :轉換好的 Tensorflow lite 模型,擴展名爲 .tflite
- output_arrays :仍然是Tensorflow 模型的輸出
- input_shapes :輸入圖片的維度
部署Android
1、安裝 官方GitHub進行Android軟件搭建 Tensorflow lite 【Github】
2、工程中有 Float 和 Quantized 兩個模式可選,如下圖,這裏使用Float,Quantized需要先量化模型,在進行 tflite 模型轉換。
3、將生成的 .tflite 文件和 對應的 labels.txt 文件放入Android工程的 assets 文件中。
4、運行即可。