【遷移學習】兩種是否加載網絡層的預訓練參數方法

前言:我最近在研究遷移學習在醫學圖像中的應用,通過摸索掌握了兩種是否加載網絡層的預訓練參數方法,具體而言就比如A卷積神經網絡一共有a、b、c、d四個有權重參數的網絡層,我可以任意選擇遷移網絡層的個數,比如我可以只遷移a或者c網絡層的參數,也可以同時遷移a、b、c多個網絡層的參數。

第一種實現方法(從模型本身入手)

例如inception_resnet_v2模型:

      with tf.variable_scope('Mixed_5b'):
        with tf.variable_scope('Branch_0'):
          tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1')
        with tf.variable_scope('Branch_1'):
          tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1')
          tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5,
                                      scope='Conv2d_0b_5x5')
        with tf.variable_scope('Branch_2'):
          tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1')
          tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3,
                                      scope='Conv2d_0b_3x3')
          tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3,
                                      scope='Conv2d_0c_3x3')
        with tf.variable_scope('Branch_3'):
          tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME',
                                       scope='AvgPool_0a_3x3')
          tower_pool_1 = slim.conv2d(tower_pool, 64, 1,
                                     scope='Conv2d_0b_1x1')
        net = tf.concat(
            [tower_conv, tower_conv1_1, tower_conv2_2, tower_pool_1], 3)

      if add_and_check_final('Mixed_5b', net): return net, end_points

        例如:這是其中的一個inception模塊,每個網絡層都有定義變量空間,例如整個inception模塊的參數都命名爲了Mixed_5b,然後其中的一個分支命名爲了Branch_1,這個分支裏的一個卷積層命名爲了Conv2d_0b_5x5。所以我們就可以得到這個卷積層的參數命名爲Mixed_5b/Branch_1/Conv2d_0b_5x5,然後通過相應的函數指定這個卷積層的預訓練參數是加載還是不加載

第二種實現方法(從保存的參數文件入手)

但第一種方法有個瓶頸,有的時候模型寫法會簡化,使用repeat函數複用,例如inception_resnet_v2模型中:

net = slim.repeat(net, 9, block8, scale=0.20, activation_fn=activation_fn)

block8是之前已經寫好的一個inception_resnet模塊,這裏通過repeat函數重複9次,大大減少了代碼的冗餘性,但是這樣就無法從模型本身入手指定網絡層的預訓練參數是否加載了。

爲此,我研究發現可以先找到保存的參數文件入手,我用的是TensorFlow框架,預訓練好的參數文件是ckpt格式,通過一下代碼就可以打印出每個變量名的空間:

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

model_dir = '../inception_resnet_v2.ckpt'

reader = pywrap_tensorflow.NewCheckpointReader(model_dir)
var_to_shape_map = reader.get_variable_to_shape_map()

for key in var_to_shape_map:
    print(key)  # tensor_name

打印結果例如

然後就可以通過相應的函數,選擇這個變量參數是否加載了。

 

以上就是以TensorFlow框架、inception_resnet_v2模型演示的兩種選擇是否加載預訓練網絡參數的方法。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章