备忘:tensorflow关于网络权重

一:用xx.npz文件初始化网络

使用tensorpack框架的时候,发现官方提供的训练好的权重文件是xx.npz格式的,我想将其某些层的参数用在自己的网络中。

import os
import random

import tensorflow as tf
import numpy as np

PRE_IMANET_NPZ = 'XX.npz'


def convert_param_name(param):
    # 获取.npz文件里的变量名称
    # 和网络的变量名称进行比较
    # 得到 网络变量名称:文件里的变量值 这样的字典
    # print('--> convert_param_name ...')
    resnet_param = {}
    for k in param.keys():
        # print(k) 
        var_name = k.replace('W', 'weights')
        var_name = var_name.replace('bn', 'BatchNorm')
        resnet_param[var_name] = param[k]
    return resnet_param


def initial_imagenet(sess, path_to_npz):
    print('Initializing through npz file trained on ImageNet ...')
    sess.run(tf.global_variables_initializer())  # 先初始化网络 避免有些网络的变量不存在在文件里
    param = np.load(path_to_npz, encoding='latin1')  # 加载文件
    param = convert_param_name(param)  # 得到 网络变量名称:文件里的变量值 这样的字典
    for var in tf.trainable_variables():  # 给网络变量进行赋值
        if var.name in param.keys():
            sess.run(var.assign(param[var.name]))
            # print(var.name, var.shape, param[var.name].shape)
            # print(var.name, 'done')
        else:
            print(var.name, var.shape, 'not in trained weights ---------------------------------')
            pass

def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    random.seed(args.seed)

    model = xxModel(num_classes=2)

    with tf.Session() as sess:
        initial_imagenet(sess=sess, path_to_npz=args.npz_file)
        train(sess, model, train_set, val_set, args.checkpoint, **train_kwargs(args))
        ...


if __name__ == '__main__':
    main()

 

二:保存训练好的权重时重命名

我们都知道tensorflow在加载模型的时候是根据名字,我使用slim框架搭建模型进行了训练,再将训练好的模型加载到另一个用tensorpack写的工程里的时候,变量命名方式不同,无法成功加载。因此我想在保存模型的时候对变量重命名。

def change_var_list():
    var_dict = {}
    for var in tf.global_variables():
        if 'Adam' not in var.name:
            var_name = var.name
            var_name = var_name.replace('BatchNorm', 'bn')
            var_name = var_name.replace('weights', 'W')
            var_name = var_name.replace('moving_mean', 'mean/EMA')
            var_name = var_name.replace('moving_variance', 'variance/EMA')
            var_name = var_name.replace('linear_C2/biases', 'linear_C2/b')
            var_dict.update({var_name: var})
    return var_dict


saver = tf.train.Saver(max_to_keep=meta_iters, var_list=change_var_list())

for i in range(step):
    # train
    # ...
    
    if i % 100 == 0 or i == step - 1:
        saver.save(sess, os.path.join(save_dir, 'model.ckpt'), global_step=i)

 

三:重命名.ckpt文件

# 读取.ckpt文件中的变量名称和对应值
from tensorflow.python import pywrap_tensorflow

checkpoint_path = './model.ckpt-200'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))

 

发布了34 篇原创文章 · 获赞 5 · 访问量 1万+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章