Tensorflow學習筆記:CNN篇(8)——Finetuning,模型更爲細化的保存與恢復

Tensorflow學習筆記:CNN篇(8)——Finetuning,模型更爲細化的保存與恢復


前序

— 對於模型的保存和恢復,前文已經做了介紹,然而讀者可能已經注意到,在設定的保存文件夾中有着4個不同的文件類型:
這裏寫圖片描述
可以得知,根據需要每個文件類型都有其不同的用處,但是僅僅知道這些還不夠,對於Tensorflow工作人員來說,需要更進一步瞭解不同文件所處的作用。


存儲文件的解讀

— 在介紹存儲文件之前,先對Saver類進行一下解釋。在不同的會話中,當需要將數據在硬盤上進行保存時,就可以使用Saver類。這個Saver構造類允許你去控制3個元素:

  • 目標(The target):設定目標。在分佈式架構的情況下,我們可以指定要計算哪個Tensorflow服務器或者“目標”
  • 圖(The graph):設置保存的圖。保存希望會話處理的圖。對於初學者來說,這裏有一個棘手的事情就是在Tensorflow中總有一個默認的圖,並且所有的操作都是在這個圖中首先進行,所以總是在“默認圖範圍”內。
  • 配置(The config):設置配置。可以使用ConfigProto參數來配置Tensorflow。

—Saver類可以處理圖中元數據和變量數據的保存和恢復,而我們唯一需要做的是,告訴Saver類需要保存哪個圖和哪些變量。在默認的情況下,Saver類能處理默認圖中包含的所有變量。但是,我們也可以創建出很多的Saver類,去保存想要的任何子圖。
—介紹完Saver類,對於模型存儲來說,這裏有4個文件類型,依次如下:

  • checkpoint:檢查點文件,記錄存儲文件名稱。
  • save_model.ckpt.data-00000-of-00001:等價於save_model.ckpt,權重存儲文件
  • save_model.ckpt.index:存儲權重目錄
  • save_model.ckpt.meta:模型的全部圖文件

在對模型進行保存和恢復時,Saver類將保存於圖像關聯的任何元數據,這意味着加載元檢查點還將恢復與圖相關聯的所有空變量、操作和集合。


代碼示例

現在拋開理論介紹而對模型進行恢復與處理。,由於Tensorflow將整體的“圖”文件存儲在meta後綴的文件中,而將權重存儲在ckpt後綴的文件中,在其具體使用時,對於模型權重的注入則是根據相應的名稱來進行,因此,如果需要對模型中不同的權重進行重新注入的話,那麼第一步就是需要賦予不同的權重以名稱。

with tf.variable_scope("var"):
        self.a_val = tf.Variable(tf.random_normal([1]),name="a_val")
        self.b_val = tf.Variable(tf.random_normal([1]),name="b_val")

這裏首先使用了tf.variable_scope對域進行了定義,之後在定義域內對輸入變量進行賦值。最終形成的名稱爲:

var/a_val

Step 1: 重新定義的線性迴歸類

首先是對於線性迴歸類的定義,在前面已經說了,需要對不同的變量或者佔位符以及不同的函數定義其在圖中的名稱,這裏爲了簡便,只定義了變量和佔位符的名稱:

import tensorflow as tf

class LineRegModel:

    def __init__(self):
        with tf.variable_scope("var"):
            self.a_val = tf.Variable(tf.random_normal([1]),name="a_val")
            self.b_val = tf.Variable(tf.random_normal([1]),name="b_val")
        self.x_input = tf.placeholder(tf.float32,name="input_placeholder")
        self.y_label = tf.placeholder(tf.float32,name="result_placeholder")
        self.y_output = tf.add(tf.multiply(self.x_input, self.a_val), self.b_val,name="output")
        self.loss = tf.reduce_mean(tf.pow(self.y_output - self.y_label, 2))

    def get_saver(self):
        return tf.train.Saver()

    def get_op(self):
        return tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)

在程序中可以看到,這裏對每個變量或佔位符都設置了相應的名稱,而對變量域又設置了對應的域名。

Step 2: 重新對模型進行訓練

import tensorflow as tf
import numpy as np
import global_variable
import lineRegulation_model as model

train_x = np.random.rand(5)
train_y = 5 * train_x + 3.2   # y = 5 * x + 3
model = model.LineRegModel()

a_val = model.a_val
b_val = model.b_val

x_input = model.x_input
y_label = model.y_label

y_output = model.y_output

loss = model.loss
optimize = model.get_op()
saver = model.get_saver()

if __name__ == "__main__":
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    flag = True
    epoch = 0
    while flag:
        epoch += 1
        _ , loss_val = sess.run([optimize,loss],feed_dict={x_input:train_x,y_label:train_y})
        if loss_val < 1e-6:
            flag = False
    print(a_val.eval(sess) , "   ", b_val.eval(sess))
    print("-----------%d-----------"%epoch)
    print(a_val.op)
    saver.save(sess,global_variable.save_path)
    print("model save finished")
    sess.close()

這裏寫圖片描述
可以看到,其中的節點名被定義爲“var/a_val”,這是類中被定義是賦予的變量名稱。

Step 3: 模型的恢復

對於模型的恢復來說,需要首先恢復模型的整個圖文件,之後從圖文件中讀取相應的節點信息。

saver = tf.train.import_meta_graph('./model/save_model.ckpt.meta')

Saver方法先從圖中獲取了整個圖的信息,之後根據節點名稱將不同的變量或者佔位符重新按名稱賦值。

#讀取placeholder和最終的輸出結果
graph = tf.get_default_graph()
a_val = graph.get_tensor_by_name('var/a_val:0')

input_placeholder=graph.get_tensor_by_name('input_placeholder:0')
labels_placeholder=graph.get_tensor_by_name('result_placeholder:0')
y_output=graph.get_tensor_by_name('output:0')#最終輸出結果的tensor

而具體的權重恢復則需要在對話中完成。

with tf.Session() as sess:
    saver.restore(sess, './model/save_model.ckpt')

完整代碼

import tensorflow as tf

saver = tf.train.import_meta_graph('./model/save_model.ckpt.meta')

#讀取placeholder和最終的輸出結果
graph = tf.get_default_graph()
a_val = graph.get_tensor_by_name('var/a_val:0')

input_placeholder=graph.get_tensor_by_name('input_placeholder:0')
labels_placeholder=graph.get_tensor_by_name('result_placeholder:0')
y_output=graph.get_tensor_by_name('output:0')

with tf.Session() as sess:
    saver.restore(sess, './model/save_model.ckpt')
    result = sess.run(y_output, feed_dict={input_placeholder: [1]})
    print(result)
    print(sess.run(a_val))

讀者可能注意到,在程序中採用通過名稱獲取對應的變量值的時候,冒號的右邊有一個0符號,這是在Tensorflow的圖運行中爲了進行參數的複用而使用的標記類型,這裏讀者可以對其忽略而直接使用,程序運行的結果如下:
這裏寫圖片描述

Step 4: 恢復模型的特定值

如果要對模型的特定值進行恢復,同樣可以使用這個首先載入圖文件之後使用權重對其賦值的辦法。

import tensorflow as tf

saver = tf.train.import_meta_graph('./model/save_model.ckpt.meta')

graph = tf.get_default_graph()
a_val = graph.get_tensor_by_name('var/a_val:0')

y_output=graph.get_tensor_by_name('output:0')

with tf.Session() as sess:
    saver.restore(sess, './model/save_model.ckpt')
    print(sess.run(a_val))

可以看到這裏只定義了變量a_val,並通過相應的名稱將其重新獲取。這種方法可以獲取到模型中特定的變量或者節點的值,其最終結果如下:
這裏寫圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章