将tensorflow保存的预训练模型读取为pytorch模型参数

用tensorflow的saver保存模型后,会有如下几个文件
在这里插入图片描述
我们可以通过graph.pbtxt看到tensorflow的计算图,方法如下:

import tensorflow as tf
import tensorflow.contrib.image
from tensorflow.python.platform import gfile
graph = tf.get_default_graph()
graphdef = graph.as_graph_def()
_ = tf.train.import_meta_graph("D:\EV-FlowNet-pth\data\log\saver\model.ckpt-600023.meta")
summary_write = tf.summary.FileWriter("./" , graph)

执行以上代码,会在当前路径下生成一个文件,然后在命令行运行:

tensorboard --logdir 刚才生成文件所在目录 --host 你的ip --port 你的端口

自己的ip一般设为127.0.0.1,默认端口一般为6006
然后就会看到输出一个url,在浏览器打开就可以看到计算图了。

如果要读取模型数据,执行以下代码:

import tensorflow as tf
reader = tf.train.NewCheckpointReader('D:\EV-FlowNet-pth\data\log\saver\model.ckpt-600023')
weights = {n: reader.get_tensor(n) for (n, _) in reader.get_variable_to_shape_map().items()}

现在weights里即为所有参数,可以查看:

for k,v in weights:
	print(k,v)

k为参数名称,v为参数值
将所有参数按你所创建的pytorch模型的参数顺序放入一个Orderdict字典里,保存即可,这里要注意不同的参数名以及参数的各个维度是否相同。
在读取时直接使用model.load_state_dict()方法即可,传入参数为你所保存的字典。
注:
在使用batchnorm时的参数,pytorch和tensorflow有不同的命名法:
tensorflow中的gamma、beta:分别是仿射中的weight、bias,在pytorch中用weight和bias表示。
tensorflow中的moving_mean、moving_variance在pytorch中表示为running_mean、running_variance

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章