將tensorflow保存的預訓練模型讀取爲pytorch模型參數

用tensorflow的saver保存模型後,會有如下幾個文件
在這裏插入圖片描述
我們可以通過graph.pbtxt看到tensorflow的計算圖,方法如下:

import tensorflow as tf
import tensorflow.contrib.image
from tensorflow.python.platform import gfile
graph = tf.get_default_graph()
graphdef = graph.as_graph_def()
_ = tf.train.import_meta_graph("D:\EV-FlowNet-pth\data\log\saver\model.ckpt-600023.meta")
summary_write = tf.summary.FileWriter("./" , graph)

執行以上代碼,會在當前路徑下生成一個文件,然後在命令行運行:

tensorboard --logdir 剛纔生成文件所在目錄 --host 你的ip --port 你的端口

自己的ip一般設爲127.0.0.1,默認端口一般爲6006
然後就會看到輸出一個url,在瀏覽器打開就可以看到計算圖了。

如果要讀取模型數據,執行以下代碼:

import tensorflow as tf
reader = tf.train.NewCheckpointReader('D:\EV-FlowNet-pth\data\log\saver\model.ckpt-600023')
weights = {n: reader.get_tensor(n) for (n, _) in reader.get_variable_to_shape_map().items()}

現在weights裏即爲所有參數,可以查看:

for k,v in weights:
	print(k,v)

k爲參數名稱,v爲參數值
將所有參數按你所創建的pytorch模型的參數順序放入一個Orderdict字典裏,保存即可,這裏要注意不同的參數名以及參數的各個維度是否相同。
在讀取時直接使用model.load_state_dict()方法即可,傳入參數爲你所保存的字典。
注:
在使用batchnorm時的參數,pytorch和tensorflow有不同的命名法:
tensorflow中的gamma、beta:分別是仿射中的weight、bias,在pytorch中用weight和bias表示。
tensorflow中的moving_mean、moving_variance在pytorch中表示爲running_mean、running_variance

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章