【迁移学习】两种是否加载网络层的预训练参数方法

前言:我最近在研究迁移学习在医学图像中的应用,通过摸索掌握了两种是否加载网络层的预训练参数方法,具体而言就比如A卷积神经网络一共有a、b、c、d四个有权重参数的网络层,我可以任意选择迁移网络层的个数,比如我可以只迁移a或者c网络层的参数,也可以同时迁移a、b、c多个网络层的参数。

第一种实现方法(从模型本身入手)

例如inception_resnet_v2模型:

      with tf.variable_scope('Mixed_5b'):
        with tf.variable_scope('Branch_0'):
          tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1')
        with tf.variable_scope('Branch_1'):
          tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1')
          tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5,
                                      scope='Conv2d_0b_5x5')
        with tf.variable_scope('Branch_2'):
          tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1')
          tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3,
                                      scope='Conv2d_0b_3x3')
          tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3,
                                      scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME',
                                       scope='AvgPool_0a_3x3')
          tower_pool_1 = slim.conv2d(tower_pool, 64, 1,
                                     scope='Conv2d_0b_1x1')
        net = tf.concat(
            [tower_conv, tower_conv1_1, tower_conv2_2, tower_pool_1], 3)

      if add_and_check_final('Mixed_5b', net): return net, end_points

        例如:这是其中的一个inception模块,每个网络层都有定义变量空间,例如整个inception模块的参数都命名为了Mixed_5b,然后其中的一个分支命名为了Branch_1,这个分支里的一个卷积层命名为了Conv2d_0b_5x5。所以我们就可以得到这个卷积层的参数命名为Mixed_5b/Branch_1/Conv2d_0b_5x5,然后通过相应的函数指定这个卷积层的预训练参数是加载还是不加载

第二种实现方法(从保存的参数文件入手)

但第一种方法有个瓶颈,有的时候模型写法会简化,使用repeat函数复用,例如inception_resnet_v2模型中:

net = slim.repeat(net, 9, block8, scale=0.20, activation_fn=activation_fn)

block8是之前已经写好的一个inception_resnet模块,这里通过repeat函数重复9次,大大减少了代码的冗余性,但是这样就无法从模型本身入手指定网络层的预训练参数是否加载了。

为此,我研究发现可以先找到保存的参数文件入手,我用的是TensorFlow框架,预训练好的参数文件是ckpt格式,通过一下代码就可以打印出每个变量名的空间:

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

model_dir = '../inception_resnet_v2.ckpt'

reader = pywrap_tensorflow.NewCheckpointReader(model_dir)
var_to_shape_map = reader.get_variable_to_shape_map()

for key in var_to_shape_map:
    print(key)  # tensor_name

打印结果例如

然后就可以通过相应的函数,选择这个变量参数是否加载了。

 

以上就是以TensorFlow框架、inception_resnet_v2模型演示的两种选择是否加载预训练网络参数的方法。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章