Tensorflow學習筆記:CNN篇(8)——Finetuning,模型更爲細化的保存與恢復
前序
— 對於模型的保存和恢復,前文已經做了介紹,然而讀者可能已經注意到,在設定的保存文件夾中有着4個不同的文件類型:
可以得知,根據需要每個文件類型都有其不同的用處,但是僅僅知道這些還不夠,對於Tensorflow工作人員來說,需要更進一步瞭解不同文件所處的作用。
存儲文件的解讀
— 在介紹存儲文件之前,先對Saver類進行一下解釋。在不同的會話中,當需要將數據在硬盤上進行保存時,就可以使用Saver類。這個Saver構造類允許你去控制3個元素:
- 目標(The target):設定目標。在分佈式架構的情況下,我們可以指定要計算哪個Tensorflow服務器或者“目標”
- 圖(The graph):設置保存的圖。保存希望會話處理的圖。對於初學者來說,這裏有一個棘手的事情就是在Tensorflow中總有一個默認的圖,並且所有的操作都是在這個圖中首先進行,所以總是在“默認圖範圍”內。
- 配置(The config):設置配置。可以使用ConfigProto參數來配置Tensorflow。
—Saver類可以處理圖中元數據和變量數據的保存和恢復,而我們唯一需要做的是,告訴Saver類需要保存哪個圖和哪些變量。在默認的情況下,Saver類能處理默認圖中包含的所有變量。但是,我們也可以創建出很多的Saver類,去保存想要的任何子圖。
—介紹完Saver類,對於模型存儲來說,這裏有4個文件類型,依次如下:
- checkpoint:檢查點文件,記錄存儲文件名稱。
- save_model.ckpt.data-00000-of-00001:等價於save_model.ckpt,權重存儲文件
- save_model.ckpt.index:存儲權重目錄
- save_model.ckpt.meta:模型的全部圖文件
在對模型進行保存和恢復時,Saver類將保存於圖像關聯的任何元數據,這意味着加載元檢查點還將恢復與圖相關聯的所有空變量、操作和集合。
代碼示例
現在拋開理論介紹而對模型進行恢復與處理。,由於Tensorflow將整體的“圖”文件存儲在meta後綴的文件中,而將權重存儲在ckpt後綴的文件中,在其具體使用時,對於模型權重的注入則是根據相應的名稱來進行,因此,如果需要對模型中不同的權重進行重新注入的話,那麼第一步就是需要賦予不同的權重以名稱。
with tf.variable_scope("var"):
self.a_val = tf.Variable(tf.random_normal([1]),name="a_val")
self.b_val = tf.Variable(tf.random_normal([1]),name="b_val")
這裏首先使用了tf.variable_scope對域進行了定義,之後在定義域內對輸入變量進行賦值。最終形成的名稱爲:
var/a_val
Step 1: 重新定義的線性迴歸類
首先是對於線性迴歸類的定義,在前面已經說了,需要對不同的變量或者佔位符以及不同的函數定義其在圖中的名稱,這裏爲了簡便,只定義了變量和佔位符的名稱:
import tensorflow as tf
class LineRegModel:
def __init__(self):
with tf.variable_scope("var"):
self.a_val = tf.Variable(tf.random_normal([1]),name="a_val")
self.b_val = tf.Variable(tf.random_normal([1]),name="b_val")
self.x_input = tf.placeholder(tf.float32,name="input_placeholder")
self.y_label = tf.placeholder(tf.float32,name="result_placeholder")
self.y_output = tf.add(tf.multiply(self.x_input, self.a_val), self.b_val,name="output")
self.loss = tf.reduce_mean(tf.pow(self.y_output - self.y_label, 2))
def get_saver(self):
return tf.train.Saver()
def get_op(self):
return tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)
在程序中可以看到,這裏對每個變量或佔位符都設置了相應的名稱,而對變量域又設置了對應的域名。
Step 2: 重新對模型進行訓練
import tensorflow as tf
import numpy as np
import global_variable
import lineRegulation_model as model
train_x = np.random.rand(5)
train_y = 5 * train_x + 3.2 # y = 5 * x + 3
model = model.LineRegModel()
a_val = model.a_val
b_val = model.b_val
x_input = model.x_input
y_label = model.y_label
y_output = model.y_output
loss = model.loss
optimize = model.get_op()
saver = model.get_saver()
if __name__ == "__main__":
sess = tf.Session()
sess.run(tf.global_variables_initializer())
flag = True
epoch = 0
while flag:
epoch += 1
_ , loss_val = sess.run([optimize,loss],feed_dict={x_input:train_x,y_label:train_y})
if loss_val < 1e-6:
flag = False
print(a_val.eval(sess) , " ", b_val.eval(sess))
print("-----------%d-----------"%epoch)
print(a_val.op)
saver.save(sess,global_variable.save_path)
print("model save finished")
sess.close()
可以看到,其中的節點名被定義爲“var/a_val”,這是類中被定義是賦予的變量名稱。
Step 3: 模型的恢復
對於模型的恢復來說,需要首先恢復模型的整個圖文件,之後從圖文件中讀取相應的節點信息。
saver = tf.train.import_meta_graph('./model/save_model.ckpt.meta')
Saver方法先從圖中獲取了整個圖的信息,之後根據節點名稱將不同的變量或者佔位符重新按名稱賦值。
#讀取placeholder和最終的輸出結果
graph = tf.get_default_graph()
a_val = graph.get_tensor_by_name('var/a_val:0')
input_placeholder=graph.get_tensor_by_name('input_placeholder:0')
labels_placeholder=graph.get_tensor_by_name('result_placeholder:0')
y_output=graph.get_tensor_by_name('output:0')#最終輸出結果的tensor
而具體的權重恢復則需要在對話中完成。
with tf.Session() as sess:
saver.restore(sess, './model/save_model.ckpt')
完整代碼
import tensorflow as tf
saver = tf.train.import_meta_graph('./model/save_model.ckpt.meta')
#讀取placeholder和最終的輸出結果
graph = tf.get_default_graph()
a_val = graph.get_tensor_by_name('var/a_val:0')
input_placeholder=graph.get_tensor_by_name('input_placeholder:0')
labels_placeholder=graph.get_tensor_by_name('result_placeholder:0')
y_output=graph.get_tensor_by_name('output:0')
with tf.Session() as sess:
saver.restore(sess, './model/save_model.ckpt')
result = sess.run(y_output, feed_dict={input_placeholder: [1]})
print(result)
print(sess.run(a_val))
讀者可能注意到,在程序中採用通過名稱獲取對應的變量值的時候,冒號的右邊有一個0符號,這是在Tensorflow的圖運行中爲了進行參數的複用而使用的標記類型,這裏讀者可以對其忽略而直接使用,程序運行的結果如下:
Step 4: 恢復模型的特定值
如果要對模型的特定值進行恢復,同樣可以使用這個首先載入圖文件之後使用權重對其賦值的辦法。
import tensorflow as tf
saver = tf.train.import_meta_graph('./model/save_model.ckpt.meta')
graph = tf.get_default_graph()
a_val = graph.get_tensor_by_name('var/a_val:0')
y_output=graph.get_tensor_by_name('output:0')
with tf.Session() as sess:
saver.restore(sess, './model/save_model.ckpt')
print(sess.run(a_val))
可以看到這裏只定義了變量a_val,並通過相應的名稱將其重新獲取。這種方法可以獲取到模型中特定的變量或者節點的值,其最終結果如下: