TensorFlow實戰———模型持久化

TensorFlow實戰———模型持久化

爲了讓訓練結果可以複用,需要將訓練得到的神經網絡模型持久化。

持久化代碼實現

TensorFlow提供了一個非常簡單的API來保存和還原一個神經網絡模型,這個API就是tf.train.Saver類。
““python
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0, shape=[1], name=”v1”)
v2 = tf.Variable(tf.constant(2.0, shape=[2], name=”v2”)
result = v1 + v2

init_op = tf.initialize_all_variables()
saver = tf.train.Saver()#聲明tf.train.Saver類用於保存模型
with tf.Session() as sess:
sess.run(init_op)
saver.save(sess, “/path/to/model/model.ckpt”)

TensorFlow模型一般會存在後綴爲.ckpt的文件中。雖然上面的程序只指定了一個文件路徑,但是在這個文件目錄下會出現三個文件,這是因爲TensorFlow會將計算圖的結構和圖上參數取值分開保存。
上面這段代碼會生成的第一個文件爲model.ckpt.meta,它保存了TensorFlow計算圖的結構。第二個文件爲model.ckpt,這個文件中保存了TensorFlow程序中每一個變量的取值。最後一個文件爲checkpoint文件,這個文件中保存了一個目錄下所有的模型文件列表。
```python
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0, shape=[1], name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[2], name="v2")
result = v1 + v2
saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, "path/to/model/model.ckpt")
    print sess.run(result)




<div class="se-preview-section-delimiter"></div>

與上面保存模型的代碼相比,兩段代碼唯一不同的是,在加載模型的代碼中沒有運行變量的初始化過程,而是將變量的值通過已經保存的模型加載進來。如果不希望重複定義圖上的運算,也可以直接加載已經持久化的圖。

import tensorflow as tf
saver = tf.train.import_meta_graph("/path/to/model/model.ckpt/model.ckpt.meta")
with tf.session() as sess:
    saver.restore(sess, "/path/to/model/model.ckpt")
    print sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))




<div class="se-preview-section-delimiter"></div>

爲了保存或者加載部分變量,在聲明tf.train.Saver類時可以提供一個列表來指定需要保存或者加載的變量。除了可以選取需要被加載的變量,tf.train.Saver類也支持在保存或者加載時給變量重命名。

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
v2 = tf.variabe(tf.constant(2.0, shape=[1]), name="other-v2")




<div class="se-preview-section-delimiter"></div>

#如果直接使用tf.train.Saver來加載模型會報變量找不到的錯誤




<div class="se-preview-section-delimiter"></div>

#使用一個字典來重命名變量就可以加載原來的模型了。這個字典指定了原來名稱爲v1的變來那個現在加載到變量v1中(名稱爲other-v1),v2同理
saver = tf.train.Saver({"v1":v1, "v2":v2})




<div class="se-preview-section-delimiter"></div>

在這個程序中,對變量v1和v2的名稱進行了修改。如果直接通過tf.train.Saver默認的構造函數來加載保存的模型,那麼程序會報變量找不到的錯誤。因爲保存時候變量的名稱和加載時變量的名稱不一致。爲了解決這個問題,TensorFlow可以通過字典將模型保存時的變量名和需要加載的變量聯繫起來。
那麼,爲什麼要使用這樣的重命名機制呢?這樣做的主要目的之一是方便使用變量的滑動平均值。
在TensorFlow中,每一個變量的滑動平均值是通過影子變量維護的,所以要獲取變量的滑動平均值實際上是獲取這個影子變量的取值。如果在加載模型時直接將影子變量映射到變量自身,那麼在使用訓練好的模型時就不需要再調用函數來獲取變量的滑動平均值了。這樣大大方便了滑動平均模型的使用。

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")




<div class="se-preview-section-delimiter"></div>

#在沒有聲明滑動平均模型時只有一個變量v,所以下面的語句只會輸出“v:0"
for variable in tf.all_variables():
    print variables.name
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_vaerages_op = ema.apply(tf.all_variables())




<div class="se-preview-section-delimiter"></div>

#在聲明滑動平均模型之後,TensorFlow會自動生成一個影子變量v/ExponentialMoving Average。於是下面的語句會輸出”v:0“和”v/ExponentialMovingAverage:0“
for variables in tf.all_variables():
    print variables.name

saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.initialize_all_varialbes()
    sess.run(init_op)
    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    #保存時,TensorFlow會將v:0和v/ExponentialMovingAverage:0兩個變量都存下來。
    saver.save(sess, "/path/to/model/model.ckpt")
    print sess.run([v, ema.average(v)])




<div class="se-preview-section-delimiter"></div>

以下代碼給出瞭如何通過變量重命名直接讀取變量的滑動平均值。

v = tf.Variable(0, dtype=tf.float32, name="v")




<div class="se-preview-section-delimiter"></div>

#通過變量重命名將原來變量v的滑動平均值直接賦值給v
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
with tf.Session() as sess:
    saver.restore(sess, "/path/to/model/model.ckpt")
    print sess.run(v)#輸出0.099999905,這個值就是原來模型中變量v的滑動平均值




<div class="se-preview-section-delimiter"></div>

爲了方便加載時重命名滑動平均變量,tf.train.ExponentialMovingAverage類提供了variables_to_restore函數來生成tf.train.Saver類所需要的變量重命名字典。

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)




<div class="se-preview-section-delimiter"></div>

#通過使用variables_to_restore函數可以直接生成上面代碼中提供的字典




<div class="se-preview-section-delimiter"></div>

#{"v/ExponentialMovingAverage":v}




<div class="se-preview-section-delimiter"></div>

#以下代碼輸出:




<div class="se-preview-section-delimiter"></div>

#{'v/ExponentialMovingAverage':<tensorflow.python.ops.varialbes.Variable object at 0x7ff6454ddc10>}
print ema.variables_to_store()
saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess, "/path/to/model/model.ckpt")
    print sess.run(v)




<div class="se-preview-section-delimiter"></div>

TensorFlow提供了convert_varialbes_to_constants函數,通過這個函數可以將計算圖中的變量及其取值通過常量的方式保存,這樣整個TensorFlow計算圖可以統一存放在一個文件中。

import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape[1]), name="v2")
result = v1 + v2
init_op = tf.initialize_all_varialbes()
with tf.Session() as sess:
    sess.run(init_op)
    #導出當前計算圖的GraphDef部分,只需要這一部分就可以完成從輸入層到輸出層的計算過程
    graph_def = tf.get_default_graph().as_graph_def()
    #將途中的變量及其取值轉化爲常量,同時將圖中不必要的節點去掉。在持久化原理中,其實一些系統運算也會被轉化爲計算圖中的節點(比如變量初始化)。如果只關心程序中定義的某些計算時,和這些計算無關的節點就沒有必要導出並保存了。在下面一行代碼中,最後一個參數['add']給出了需要保存的節點名稱。add節點是上面定義的兩個變量相加的操作。注意這裏給出的是計算節點的名稱,所以沒有後面的:0
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
    #將導出的模型存入文件
    with tf.gfile.GFile("/path/to/model/combined_model.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())




<div class="se-preview-section-delimiter"></div>

通過下面的程序可以直接計算定義的加法運算的結果。

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = "/path/to/model/combined_model.pb"
    #讀取保存的模型文件,並將文件解析成對應的GraphDef Protocol Buffer
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    #將graph_def中保存的圖加載到當前的圖中,return_elements=["add:0"]給出了返回的張量的名稱。在保存的時候給出的是計算節點的名稱,所以爲“add”,在加載的時候給出的是張量的名稱,所以是add:0
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print sess.run(result)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章