Tensorflow 模型轉 tflite ,在安卓端使用

自己在將tensorflow模型移動端部署的時候(使用 tensorflow lite),踩了很多坑,查了很多資料,現在做個記錄,所有參考資料在文章最後 參考 處列出。

tensorflow lite是TensorFlow Lite 是 Google I/O 2017 大會上的其中一個重要宣佈,有了TensorFlow Lite,應用開發者可以在移動設備上部署人工智能。
tensorflow lite 【github】

這裏寫圖片描述

基本思路:

  1. 在pc端進行 Tensorflow 模型訓練,保存訓練模型
  2. 使用 工具將該模型轉換爲 Tensorflow lite 模型
  3. 在Android上使用

tensorflow模型持久化

在tensorflow中進行模型訓練,得到適合自己項目的模型。Tensorflow 模型訓練好之後會生成三個文件:

  • model.ckpt.meta :保存Tensorflow計算圖結構,可以理解爲神經網絡的網絡結構
  • model.ckpt :保存Tensorflow程序中每一個變量的取值,變量是模型中可訓練的部分
  • checkpoint :保存一個目錄下所有模型文件列表
# 使用tf.train.write_graph導出GraphDef文件
tf.train.write_graph(sess.graph_def, "./", "mz_graph.pb", as_text=False)
# 使用tf.train.save導出checkpoint文件
saver.save(sess, model_path)

生成的模型文件如下圖所示:
這裏寫圖片描述

bazel編譯需要的工具

Tensoflow使用的編譯工具是 bazel,谷歌開源的自動化構建工具。【bazel傳送門】
安裝bazel,用來編譯 tensorflow 轉 tflite 時用到的幾個工具,freeze、toco、summarize_graph(具體作用下面說),這些工具都在 tensorflow(從github上clone) 中,按下面命令進行編譯(在 tensorflow目錄下進行):

bazel build tensorflow/python/tools:freeze_graph

bazel build tensorflow/contrib/lite/toto:toto

Bazel build tensorflow/tools/graph_transforms:summarize_graph  (查看模型結構,找出輸入輸出)

模型轉換

將訓練好的tf模型,進行freeze、toco操作,freeze主要是將 tensorflow模型持久化 中生成的文件進行合併,得到一個變量值和運算圖模型相結合的文件,是將變量值固定在圖中的操作。如上圖,這步生成 mz_freezegraph.pb .

summarize_graph

該命令查看整個Tensorflow模型概況,使用命令如下,運行之後,得到自己整個網絡結構,從中可以找到自己模型的輸入輸出,如下圖(模型比較亂。。。)

# --in_graph=” 後面是模型存儲的位置
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=../mz_graph.pb

這裏寫圖片描述

freeze_graph

該命令是 Tensorflow模型固化,將Tensorflow模型和計算圖上變量的值合二爲一,方便直接轉換 Tensorflow lite 模型。

    bazel-bin/tensorflow/python/tools/freeze_graph\
        --input_graph=/tmp/mobilenet_v1_224.pb \
        --input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \
        --input_binary=true \
        --output_graph=/tmp/frozen_mobilenet_v1_224.pb \
        --output_node_names=MobileNet/Predictions/Reshape_1
  • input_graph :Tensorflow 模型結構文件
  • input_checkpoint :Tensorflow 模型 ckpt 文件
  • output_graph :輸出的freeze文件
  • output_node_names :模型輸出節點名字,使用 summarize_graph 查看 ,可以在 Tensorflow 網絡訓練時進行命名

這裏寫圖片描述

toco

固化模型到 tflite 模型轉化

toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
      --input_format=TENSORFLOW_GRAPHDEF \
      --output_format=TFLITE \
      --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
      --inference_type=FLOAT \
      --input_type=FLOAT \
      --input_arrays=input \
      --output_arrays=MobilenetV1/Predictions/Reshape_1 \
      --input_shapes=1,224,224,3
  • input_file : freeze 之後的 Tensorflow 模型文件
  • output_file :轉換好的 Tensorflow lite 模型,擴展名爲 .tflite
  • output_arrays :仍然是Tensorflow 模型的輸出
  • input_shapes :輸入圖片的維度

這裏寫圖片描述

部署Android

1、安裝 官方GitHub進行Android軟件搭建 Tensorflow lite 【Github】
2、工程中有 FloatQuantized 兩個模式可選,如下圖,這裏使用Float,Quantized需要先量化模型,在進行 tflite 模型轉換。
3、將生成的 .tflite 文件和 對應的 labels.txt 文件放入Android工程的 assets 文件中。
4、運行即可。

這裏寫圖片描述

參考

  1. TensorFlow Lite學習筆記2:生成TFLite模型文件
  2. TensorFlow固化模型
  3. TensorFlow Lite模型生成以及bazel的安裝使用、出現的問題及解決方案整合
  4. Tensorflow Lite之編譯生成tflite文件
  5. tensorflow Lite的使用
  6. tensorflow模型量化
  7. 用 TensorFlow 壓縮神經網絡
  8. 在Android上使用TensorFlow Lite
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章