乾貨!如何修改在TensorFlow框架下訓練保存的模型參數名稱

乾貨!如何修改在TensorFlow框架下訓練保存的模型參數名稱

爲何要修改TensorFlow訓練的模型參數名?

在TensorFlow框架下的深度學習程序中,我們將訓練得到的模型參數進行保存。在我們進行某些訓練任務時,也許要從已經保存的預訓練模型中載入參數,或者將TensorFlow框架訓練得到的參數轉換到其他框架使用。在進行上述操作的時候,有可能需要將已訓練的模型參數(尤其是參數名稱)做出改變。

舉個例子,筆者最近在做實驗,打算使用DeepLab V2作基準(baseline),並下載預訓練的DeepLab V2模型參數進行自己的模型的初始化,然後進行微調(finetune)。由於筆者的任務跟遷移學習比較相關,在程序中需要在某個參數域(variable_scope)下大量重用(reuse)參數。可是,網上下載的DeepLab V2模型不會考慮到筆者自己設置的參數域。因此,在使用下載的模型參數,使用鍵值對的方式進行參數初始化時,會報錯(參數名匹配不上)。因此,筆者需要修改下載的預訓練模型中的參數名稱,在每個參數前面加上筆者在自己的程序中設定的參數域名。這樣在程序中纔可以既在指定的參數域下重用參數,又可以使用預訓練的參數進行初始化

筆者觀察到,介紹TensorFlow保存讀取參數的博客很多,但是很少有介紹修改已保存參數的博客。而修改參數在某些情況下是會使用到的。因此在本篇博客中,筆者就介紹怎麼在程序中對已保存的模型參數(名稱)進行修改,在筆者的程序中,也可以對參數的值進行修改。

如何修改TensorFlow訓練保存的參數名?

在筆者最開始進行參數名稱修改摸索的時候,網上的資源真是少之又少,搜索了一段時間後。筆者看到一篇文檔進行了介紹並附帶了代碼,大家可以移步這篇知乎專欄:Tensorflow修改已訓練模型變量名字的方法。這篇專欄對筆者的幫助比較大,筆者也借鑑了裏面的少量代碼。

可是,筆者覺得上述專欄裏面的做法比較繁瑣。因爲,從代碼裏面可以看到,在修改模型參數的時候,進行了讀取數據流圖(Graph)的操作。可是,在我們使用預訓練模型初始化的時候,是按照字典,即鍵值對的方式進行初始化的。具體解釋就是按照我們定義的參數名稱,去已保存的模型參數裏面讀取對應的值來初始化。因此,筆者認爲沒有必要專門讀取數據流圖,並進行了更簡潔的嘗試。

在放出代碼之前,筆者先介紹一下用到的兩個重要的接口:

  1. tf.contrib.framework.list_variables。將已保存參數的(名稱,形狀)以列表的形式返回。在更新的TensorFlow版本中,該接口已經被整合到了tf.train.list_variables裏面。
  2. tf.contrib.framework.load_variable。可以傳入名稱,返回讀取的已保存參數的值。在更新的TensorFlow版本中,該接口已經被整合到了tf.train.load_variable裏面。

在修改保存的參數名稱時,做法分爲以下6步:

  1. 使用list_variables函數逐個讀出已保存的參數名稱
  2. 使用load_variable函數逐個讀取已保存的參數值
  3. 逐個修改參數名稱
  4. 使用已修改的參數名稱,結合tf.Variable函數逐個重建參數
  5. 將已重建的參數逐個加入新參數列表
  6. 使用tf.train.Saver().save將新參數列表寫入硬盤

下面放出筆者的代碼,在代碼中,筆者給DeepLab V2預訓練的模型參數全加上了前綴“deeplab_v2”。在這裏筆者使用的還是許久之前的DeepLab預訓練模型,參數保存還是一個ckpt文件(deeplab_resnet.ckpt)。代碼如下:

import tensorflow as tf
import argparse
import os

parser = argparse.ArgumentParser(description='')

parser.add_argument("--checkpoint_path", default='../deeplab_resnet/deeplab_resnet.ckpt', help="restore ckpt") #原參數路徑
parser.add_argument("--new_checkpoint_path", default='../deeplab_resnet_altered/', help="path_for_new ckpt") #新參數保存路徑
parser.add_argument("--add_prefix", default='deeplab_v2/', help="prefix for addition") #新參數名稱中加入的前綴名

args = parser.parse_args()


def main():
    if not os.path.exists(args.new_checkpoint_path):
        os.makedirs(args.new_checkpoint_path)
    with tf.Session() as sess:
        new_var_list=[] #新建一個空列表存儲更新後的Variable變量
        for var_name, _ in tf.contrib.framework.list_variables(args.checkpoint_path): #得到checkpoint文件中所有的參數(名字,形狀)元組
            var = tf.contrib.framework.load_variable(args.checkpoint_path, var_name) #得到上述參數的值

            new_name = var_name
            new_name = args.add_prefix + new_name #在這裏加入了名稱前綴,大家可以自由地作修改

            #除了修改參數名稱,還可以修改參數值(var)

            print('Renaming %s to %s.' % (var_name, new_name))
            renamed_var = tf.Variable(var, name=new_name) #使用加入前綴的新名稱重新構造了參數
            new_var_list.append(renamed_var) #把賦予新名稱的參數加入空列表

        print('starting to write new checkpoint !')
        saver = tf.train.Saver(var_list=new_var_list) #構造一個保存器
        sess.run(tf.global_variables_initializer()) #初始化一下參數(這一步必做)
        model_name = 'deeplab_resnet_altered' #構造一個保存的模型名稱
        checkpoint_path = os.path.join(args.new_checkpoint_path, model_name) #構造一下保存路徑
        saver.save(sess, checkpoint_path) #直接進行保存
        print("done !")

if __name__ == '__main__':
    main()

在終端下面運行一下代碼:
在這裏插入圖片描述
可以看到參數名稱都被重置了,加上了前綴“deeplab_v2”:
在這裏插入圖片描述
在代碼中設定的保存文件夾下,能夠查看已保存的新參數名稱的模型參數:
在這裏插入圖片描述
由於後來的TensorFlow框架在保存模型時已經放棄了保存單個ckpt文件的做法,因此都是得到4個文件,如上所示。

然後我們就可以在代碼中愉快地使用新參數名稱的模型進行初始化啦~

loader = tf.train.Saver(var_list=restore_vars) #設置一下要初始化哪些參數
checkpoint = tf.train.latest_checkpoint(args.checkpoint_path) #保存的新參數名的模型路徑
loader.restore(sess, ckpt_path) #初始化模型參數

到這裏,本篇博文就接近尾聲了。本篇博文主要講述瞭如何修改TensorFlow框架下訓練的參數名稱,核心還是找出參數名->更改參數名->重建參數->保存。筆者也衷心希望本篇博客能對大家的科研與工作有幫助。

歡迎閱讀筆者後續博客,各位讀者朋友的支持與鼓勵是我最大的動力

written by jiong
謙,亨,君子有終。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章