Keras神經網絡轉到Android可用的模型

這是一篇對手冊性質的文章,如果你剛好從事AI開發,可以參考這文章來進行模型轉換。

Keras轉TFLite需要三個過程,

  1. Keras 轉 Tensorflow
  2. 固化 Tensorflow 網絡到 PB(Protocol Buffer)
  3. PB 轉 TFLite

Keras 網絡構成

Keras網絡有一個文件(正常情況)

  • *.h5 它是HDF5格式文件,同時保存了網絡結構和網絡參數。

Tensorflow 網絡的構成

Tensorflow 常見的描述網絡結構文件是 ckpt,它有兩個文件構成

  • model.ckpt
  • model.ckpt.meta 新版本的 Tensorflow 的 Saver 會默認使用新格式保存,新格式的文件是這幾個
  • model.ckpt.data-00000-of-00001
  • model.ckpt.index
  • model.ckpt.meta Tensorflow自從開源之後就經常有改動,目前還不確定新格式的三個文件是什麼作用跟含義。 就暫時以最穩定的老版本格式來解釋。
  • model.ckpt 這個文件記錄了神經網絡上節點的權重信息,也就是節點上 wx+b 的取值。
  • model.ckpt.meta 這個文件主要記錄了圖結構,也就是神經網絡的節點結構。

一個完整的神經網絡由這兩部分構成,Tensorflow 在保存時除了這兩個文件還會在目錄下自動生成 checkpoint, checkpoint的內容如下,它只記錄了目錄下有哪些網絡。

model_checkpoint_path: "squeezenet_model.ckpt" all_model_checkpoint_paths: "squeezenet_model.ckpt"

Keras 轉 Tensorflow

轉換過程需要先把網絡結構和權重加載到model對象, 然後用 tf.train.Saver 來保存爲 ckpt 文件。

目前代碼是以V1爲基礎的,指定Saver版本可以在構建Saver的時候指定參數 saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) saver.save(K.get_session(), './squeezenet_model.ckpt')

CKPT freeze 到 PB

ckpt的網絡結構和權重還是分開的 需要先固化到PB,才能繼續轉成 tflite。

Tensorflow 提供了python腳本用來固化,位置在

/usr/local/lib/python3.6/site-packages/tensorflow/python/tools/freeze_graph.py

對於固化的過程需要關注這幾個參數

  • input_meta_graph: meta 文件,也就是節點結構
  • input_checkpoint: ckpt 文件,保存權重
  • output_graph: 輸出PB文件的名稱
  • output_node_names: 網絡輸出節點
  • input_binary: 輸入文件是否爲二進制 下面的命令直接給出瞭如何轉換,對於幾個參數的意義比較難理解的是倒數第二個,文章後面再給出對它的解釋。

python3 freeze_graph.py \ --input_meta_graph=model.ckpt.meta \ --input_checkpoint=model.ckpt \ --output_graph=model.pb \ --output_node_names="final_result" \ --input_binary=true

PB 到 Tensorflow Lite

Tensorflow 提供了 TOCO 工具用來做轉換, 必填的參數有下面這些,

toco --graph_def_file=squeezenet_model.pb \ --input_format=TENSORFLOW_GRAPHDEF \ --output_format=TFLITE \ --output_file=model.tflite \ --inference_type=FLOAT \ --input_type=FLOAT \ --input_arrays=input \ --output_arrays=final_result \ --input_sahpes=1,227,227,3

參數中需要解釋的有這幾個, --input_shapes: 輸入數據的維度,跟你的網絡輸入有關。比如1,227,227,3,代表的是1個227*227的3通道圖片。 --output_arrays 和 --input_arrays: 這兩個參數跟網絡的輸入輸出有關。而 output_arrays 跟轉換成 PB 時的參數 --output_node_names 是一樣的。 也就是說這兩個參數必須在查看網絡之後才能確定 下面給出如何查看網絡的方法

查看PB網絡結構

在tensorflow包下面,跟freeze_graph.py同個目錄下有另一個腳本

import_pb_to_tensorboard.py

它接受一個protobuf文件作爲輸入,並輸出log到指定路徑。之後可以就用tensorboard查看log文件了。 tensorboard是一個把網絡視圖話的工具,可以在瀏覽器上直接查看網絡結構。 運行

python3 import_pb_to_tensorboard.py --model_dir model.pb --log_dir board/

如果環境沒問題的話會在board/目錄下生產 local文件, 你會在終端看到tensorflow的提示,

Model Imported. Visualize by running: tensorboard --logdir=board/

按提示執行tensorboard,就可以在瀏覽器中通過 localhost:6006 查看網絡結構了。 需要關注的是網絡的輸入和輸出節點的命名, 而它的命名就是上面幾個步驟中我們需要的參數名了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章