一:用xx.npz文件初始化網絡
使用tensorpack框架的時候,發現官方提供的訓練好的權重文件是xx.npz格式的,我想將其某些層的參數用在自己的網絡中。
import os
import random
import tensorflow as tf
import numpy as np
PRE_IMANET_NPZ = 'XX.npz'
def convert_param_name(param):
# 獲取.npz文件裏的變量名稱
# 和網絡的變量名稱進行比較
# 得到 網絡變量名稱:文件裏的變量值 這樣的字典
# print('--> convert_param_name ...')
resnet_param = {}
for k in param.keys():
# print(k)
var_name = k.replace('W', 'weights')
var_name = var_name.replace('bn', 'BatchNorm')
resnet_param[var_name] = param[k]
return resnet_param
def initial_imagenet(sess, path_to_npz):
print('Initializing through npz file trained on ImageNet ...')
sess.run(tf.global_variables_initializer()) # 先初始化網絡 避免有些網絡的變量不存在在文件裏
param = np.load(path_to_npz, encoding='latin1') # 加載文件
param = convert_param_name(param) # 得到 網絡變量名稱:文件裏的變量值 這樣的字典
for var in tf.trainable_variables(): # 給網絡變量進行賦值
if var.name in param.keys():
sess.run(var.assign(param[var.name]))
# print(var.name, var.shape, param[var.name].shape)
# print(var.name, 'done')
else:
print(var.name, var.shape, 'not in trained weights ---------------------------------')
pass
def main():
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
random.seed(args.seed)
model = xxModel(num_classes=2)
with tf.Session() as sess:
initial_imagenet(sess=sess, path_to_npz=args.npz_file)
train(sess, model, train_set, val_set, args.checkpoint, **train_kwargs(args))
...
if __name__ == '__main__':
main()
二:保存訓練好的權重時重命名
我們都知道tensorflow在加載模型的時候是根據名字,我使用slim框架搭建模型進行了訓練,再將訓練好的模型加載到另一個用tensorpack寫的工程裏的時候,變量命名方式不同,無法成功加載。因此我想在保存模型的時候對變量重命名。
def change_var_list():
var_dict = {}
for var in tf.global_variables():
if 'Adam' not in var.name:
var_name = var.name
var_name = var_name.replace('BatchNorm', 'bn')
var_name = var_name.replace('weights', 'W')
var_name = var_name.replace('moving_mean', 'mean/EMA')
var_name = var_name.replace('moving_variance', 'variance/EMA')
var_name = var_name.replace('linear_C2/biases', 'linear_C2/b')
var_dict.update({var_name: var})
return var_dict
saver = tf.train.Saver(max_to_keep=meta_iters, var_list=change_var_list())
for i in range(step):
# train
# ...
if i % 100 == 0 or i == step - 1:
saver.save(sess, os.path.join(save_dir, 'model.ckpt'), global_step=i)
三:重命名.ckpt文件
# 讀取.ckpt文件中的變量名稱和對應值
from tensorflow.python import pywrap_tensorflow
checkpoint_path = './model.ckpt-200'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
print(reader.get_tensor(key))