備忘:tensorflow關於網絡權重

一:用xx.npz文件初始化網絡

使用tensorpack框架的時候,發現官方提供的訓練好的權重文件是xx.npz格式的,我想將其某些層的參數用在自己的網絡中。

import os
import random

import tensorflow as tf
import numpy as np

PRE_IMANET_NPZ = 'XX.npz'


def convert_param_name(param):
    # 獲取.npz文件裏的變量名稱
    # 和網絡的變量名稱進行比較
    # 得到 網絡變量名稱:文件裏的變量值 這樣的字典
    # print('--> convert_param_name ...')
    resnet_param = {}
    for k in param.keys():
        # print(k) 
        var_name = k.replace('W', 'weights')
        var_name = var_name.replace('bn', 'BatchNorm')
        resnet_param[var_name] = param[k]
    return resnet_param


def initial_imagenet(sess, path_to_npz):
    print('Initializing through npz file trained on ImageNet ...')
    sess.run(tf.global_variables_initializer())  # 先初始化網絡 避免有些網絡的變量不存在在文件裏
    param = np.load(path_to_npz, encoding='latin1')  # 加載文件
    param = convert_param_name(param)  # 得到 網絡變量名稱:文件裏的變量值 這樣的字典
    for var in tf.trainable_variables():  # 給網絡變量進行賦值
        if var.name in param.keys():
            sess.run(var.assign(param[var.name]))
            # print(var.name, var.shape, param[var.name].shape)
            # print(var.name, 'done')
        else:
            print(var.name, var.shape, 'not in trained weights ---------------------------------')
            pass

def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    random.seed(args.seed)

    model = xxModel(num_classes=2)

    with tf.Session() as sess:
        initial_imagenet(sess=sess, path_to_npz=args.npz_file)
        train(sess, model, train_set, val_set, args.checkpoint, **train_kwargs(args))
        ...


if __name__ == '__main__':
    main()

 

二:保存訓練好的權重時重命名

我們都知道tensorflow在加載模型的時候是根據名字,我使用slim框架搭建模型進行了訓練,再將訓練好的模型加載到另一個用tensorpack寫的工程裏的時候,變量命名方式不同,無法成功加載。因此我想在保存模型的時候對變量重命名。

def change_var_list():
    var_dict = {}
    for var in tf.global_variables():
        if 'Adam' not in var.name:
            var_name = var.name
            var_name = var_name.replace('BatchNorm', 'bn')
            var_name = var_name.replace('weights', 'W')
            var_name = var_name.replace('moving_mean', 'mean/EMA')
            var_name = var_name.replace('moving_variance', 'variance/EMA')
            var_name = var_name.replace('linear_C2/biases', 'linear_C2/b')
            var_dict.update({var_name: var})
    return var_dict


saver = tf.train.Saver(max_to_keep=meta_iters, var_list=change_var_list())

for i in range(step):
    # train
    # ...
    
    if i % 100 == 0 or i == step - 1:
        saver.save(sess, os.path.join(save_dir, 'model.ckpt'), global_step=i)

 

三:重命名.ckpt文件

# 讀取.ckpt文件中的變量名稱和對應值
from tensorflow.python import pywrap_tensorflow

checkpoint_path = './model.ckpt-200'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))

 

發佈了34 篇原創文章 · 獲贊 5 · 訪問量 1萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章