TensorFlow中的模型持久化

本文參考《TensorFlow實戰Google深度學習框架》一書,總結了一些在TensorFlow在保存訓練好的模型過程中使用到的一些API

TF提供了tf.train.Saver類來保存和還原一個神經網絡模型

1.模型保存

模型保存的代碼如下所示:先聲明一個tf.train.Saver對象saver,然後使用saver.save進行保存,該函數的第二個參數是保存的路徑。注意保存的文件名後綴爲.ckpt。雖然只指定了一個文件路徑,但是最終會產生多個文件。
import tensorflow as tf

PATH = 'C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/model.ckpt'

'''part1:保存模型'''
v1 = tf.Variable(tf.constant([1]), name = 'v1')
v2 = tf.Variable(tf.constant([2]), name = 'v2')
result = v1 + v2
init_op = tf.initialize_all_variables()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(result))
    saver.save(sess, 'C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/model.ckpt')
產生的文件截圖

運行上述代碼,就可以將整個網絡結構和相關的數據都保存下來。

2.模型恢復

TF中可以使用saver.restore來恢復之前已經保存的模型,以下代碼給出了加載這個已經保存的模型的方法:
v1 = tf.Variable(tf.constant([1]), name = 'v1')
v2 = tf.Variable(tf.constant([2]), name = 'v2')
result = v1 + v2

saver = tf.train.Saver()
with tf.Session() as sess2:
    saver.restore(sess2, 'C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/model.ckpt')
    print(sess2.run(result))
需要注意的是,要使用這種方法恢復模型,則v1, v2和result都必須重新聲明,否則會報錯,而變量的初始化操作則可以不必運行,被換成了加載已經保存了的模型。此處的輸出爲[3]。在這個過程中只要name屬性保持'v1'和'v2',用於表示變量的標識符是可以改變的。如v1_,v2_。
如果不希望重複定義上面的變量v1, v2, result,則可以使用另外一種方法來加載已經保存的模型。
import tensorflow as tf
saver = tf.train.import_meta_graph('C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess,'C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/model.ckpt')
    print('result',sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))
值得注意的是在這個過程中兩次使用到的路徑名中調用的文件有所不同,model.ckpt.meta保存了計算圖的結構。而restore中的路徑與前面的例子完全相同。在上面的例子中result的name屬性爲'add:0',表示其是add這個操作的第一個輸出值。我們使用tf.get_default_graph().get_tensor_by_name('add:0')函數來獲取當前的計算圖和result變量。最後的輸出結果是[3]。

在上面的例子中,默認保存和加載了TF計算圖上定義的全部變量,但是有的時候可能只需要保存或者加載一部分變量。爲了實現這個功能,在聲明tf.trian.Saver的對象的時候,可以提供一個列表來指定需要保存或者加載的變量。比如
saver = tf.train.Saver([v1])
在這個過程中就只有變量v1會被保存和加載。如果在這個過程中,想要通過加載已經保存的模型,並且輸出v2或者result的值,都會報錯,而v1還是可以正常輸出[1]。

3.變量重命名

除了可以選取需要被加載的變量,tf.train.Saver類也支持在保存或者加載時給變量重命名。如果先前已經保存了v1和v2兩個變量,其name依次爲'v1'和'v2', 我們可以在加載的過程中,聲明變量的時候對這兩個變量進行重命名。該過程如下所示:
v1 = tf.Variable(tf.constant([1]), name = 'other-v1')
v2 = tf.Variable(tf.constant([2]), name = 'other-v2')
saver =  tf.train.Saver({'v1': v1, 'v2': v2})
with tf.Session() as sess2:
    saver.restore(sess2, 'C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/model.ckpt')
    print(sess2.run(v1))
我們分別爲v1和v2兩個變量重命名爲'other-v1', 'other-v2',在加載模型的時候,如果聲明對象的過程中tf.train.Saver()括號中爲空,則會報找不到變量的錯誤。爲了重命名,在括號中加入了一個字典,將原來name爲'v1'的變量保存到當前的v1變量中,而當前的v1變量的名稱爲‘other-v1’。v2則同理。當然在此處變量也可以以任意合法的標識符來定義,如v1_,v2_,只要保持統一,這樣都是合法的。
這樣可以方便使用變量的滑動平均值。只需要將影子變量映射到變量自身,那麼在使用訓練好的模型時就不需要再調用函數來獲取變量的滑動平均值了。爲了方便加載時重命名滑動平均變量,tf.train.ExponentialMovingAverage類提供了variables_to_restore函數來生成tf.train.Saver類所需要的變量重命名字典。
import tensorflow as tf
v = tf.Variable(0, name = 'v')
ema = tf.train.ExponentialMovingAverage(0.99)
saver = tf.train.Saver(ema.variables_to_restore())

4.其他

在測試或者離線預測時,只需要知道如何從神經網絡的輸入層經過前向傳播得到輸出層即可,而不需要變量初始化,模型保存等輔助結點的信息。在遷移學習的過程中,也會遇到類似的情況。TF中提供了convert_variables_to_constants函數,通過這個函數可以將計算圖中的變量及其取值通過常量的方式保存,這樣整個TF計算圖可以統一存放在一個文件中。下面的程序提供了一個例子。
import tensorflow as tf
from tensorflow.python.framework import  graph_util

sess = tf.InteractiveSession()

v1 = tf.Variable(tf.constant(1.0, shape = [1]), name = 'v1')
v2 = tf.Variable(tf.constant(2.0, shape = [1]), name = 'v2')
result = v1 + v2

tf.initialize_all_variables().run()

#'''導出當前計算圖的GraphDef部分,只需要這個一個部分就可以完成從輸入層到輸出層的計算過程'''
graph_def = tf.get_default_graph().as_graph_def()
#將圖中的變量及其取值轉化爲常量,同時將圖中不必要的結點去掉(比如變量的初始化操作)
#如果只關心程序中定義的某些計算時,和這些計算無關的節點就沒有必要保存了。在下面一行代碼中,
#最後一個參數['add']給出了需要保存的節點名稱。add節點是上面定義的兩個變量相加的操作。
out_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
#將導出的模型存入文件
with tf.gfile.GFile('C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/combined_model.pb', 'wb') as f:
    f.write(out_graph_def.SerializeToString())
生成的文件截圖:

通過下面的程序可以直接計算定義的加法運算的結果。當只需要得到計算圖中某個節點的取值時,這提供了一個更加方便的方法。
import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = 'C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/combined_model.pb'
    #讀取保存的模型文件,並且將文件解析成對應的GraphDef Protocol Buffer
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    #將graph_def中保存的圖加載到當前的圖中。retrun_elements = ['add:0']給出了返回張量的名稱
    #在保存的時候給出的是計算節點的名稱'add',在加載的時候給出的是張量的名稱,所以是‘add:0'
    result = tf.import_graph_def(graph_def, return_elements = ['add:0'])
    print(sess.run(result))

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章