原文地址:https://cloud.tencent.com/developer/article/1370372
這是一篇對手冊性質的文章,如果你剛好從事AI開發,可以參考這文章來進行模型轉換。
Keras轉TFLite需要三個過程,
- Keras 轉 Tensorflow
- 固化 Tensorflow 網絡到 PB(Protocol Buffer)
- PB 轉 TFLite
Keras 網絡構成
Keras網絡有一個文件(正常情況)
- *.h5 它是HDF5格式文件,同時保存了網絡結構和網絡參數。
Tensorflow 網絡的構成
Tensorflow 常見的描述網絡結構文件是 ckpt,它有兩個文件構成
- model.ckpt
- model.ckpt.meta 新版本的 Tensorflow 的 Saver 會默認使用新格式保存,新格式的文件是這幾個
- model.ckpt.data-00000-of-00001
- model.ckpt.index
- model.ckpt.meta Tensorflow自從開源之後就經常有改動,目前還不確定新格式的三個文件是什麼作用跟含義。 就暫時以最穩定的老版本格式來解釋。
- model.ckpt 這個文件記錄了神經網絡上節點的權重信息,也就是節點上 wx+b 的取值。
- model.ckpt.meta 這個文件主要記錄了圖結構,也就是神經網絡的節點結構。
一個完整的神經網絡由這兩部分構成,Tensorflow 在保存時除了這兩個文件還會在目錄下自動生成 checkpoint, checkpoint的內容如下,它只記錄了目錄下有哪些網絡。
model_checkpoint_path: "squeezenet_model.ckpt" all_model_checkpoint_paths: "squeezenet_model.ckpt"
Keras 轉 Tensorflow
轉換過程需要先把網絡結構和權重加載到model對象, 然後用 tf.train.Saver 來保存爲 ckpt 文件。
目前代碼是以V1爲基礎的,指定Saver版本可以在構建Saver的時候指定參數 saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) saver.save(K.get_session(), './squeezenet_model.ckpt')
CKPT freeze 到 PB
ckpt的網絡結構和權重還是分開的 需要先固化到PB,才能繼續轉成 tflite。
Tensorflow 提供了python腳本用來固化,位置在
/usr/local/lib/python3.6/site-packages/tensorflow/python/tools/freeze_graph.py
對於固化的過程需要關注這幾個參數
- input_meta_graph: meta 文件,也就是節點結構
- input_checkpoint: ckpt 文件,保存權重
- output_graph: 輸出PB文件的名稱
- output_node_names: 網絡輸出節點
- input_binary: 輸入文件是否爲二進制 下面的命令直接給出瞭如何轉換,對於幾個參數的意義比較難理解的是倒數第二個,文章後面再給出對它的解釋。
python3 freeze_graph.py \ --input_meta_graph=model.ckpt.meta \ --input_checkpoint=model.ckpt \ --output_graph=model.pb \ --output_node_names="final_result" \ --input_binary=true
PB 到 Tensorflow Lite
Tensorflow 提供了 TOCO 工具用來做轉換, 必填的參數有下面這些,
toco --graph_def_file=squeezenet_model.pb \ --input_format=TENSORFLOW_GRAPHDEF \ --output_format=TFLITE \ --output_file=model.tflite \ --inference_type=FLOAT \ --input_type=FLOAT \ --input_arrays=input \ --output_arrays=final_result \ --input_sahpes=1,227,227,3
參數中需要解釋的有這幾個, --input_shapes: 輸入數據的維度,跟你的網絡輸入有關。比如1,227,227,3,代表的是1個227*227的3通道圖片。 --output_arrays 和 --input_arrays: 這兩個參數跟網絡的輸入輸出有關。而 output_arrays 跟轉換成 PB 時的參數 --output_node_names 是一樣的。 也就是說這兩個參數必須在查看網絡之後才能確定 下面給出如何查看網絡的方法
查看PB網絡結構
在tensorflow包下面,跟freeze_graph.py同個目錄下有另一個腳本
import_pb_to_tensorboard.py
它接受一個protobuf文件作爲輸入,並輸出log到指定路徑。之後可以就用tensorboard查看log文件了。 tensorboard是一個把網絡視圖話的工具,可以在瀏覽器上直接查看網絡結構。 運行
python3 import_pb_to_tensorboard.py --model_dir model.pb --log_dir board/
如果環境沒問題的話會在board/目錄下生產 local文件, 你會在終端看到tensorflow的提示,
Model Imported. Visualize by running: tensorboard --logdir=board/
按提示執行tensorboard,就可以在瀏覽器中通過 localhost:6006 查看網絡結構了。 需要關注的是網絡的輸入和輸出節點的命名, 而它的命名就是上面幾個步驟中我們需要的參數名了。