Tensorflow 學習筆記之 共享變量(Sharing Variables)

Tensorflow 學習筆記之 共享變量(Sharing Variables)


最近兩年,谷歌撐腰的深度學習框架Tensorflow發展地如日中天,雖然17年pytorch的出現略微“打壓”了一些TF的勢頭,但TF在深度學習界的地位還是難以撼動的,github上TF的收藏量一直穩在深度學習中前二的位置。個人在4月份開始接觸TF,寫分類、超分辨網絡不亦樂乎。然而,最近從越來越多的TF github項目中看到了人們都在使用一個叫“共享變量”的機制管理變量,已經基本學會簡單TF語法的我,今天決定好好研究一下這個功能。

變量管理的問題

設想你要寫一個分類網絡,結構是“卷積->ReLU->Pooling->卷積->ReLU->Pooling->展平->全連接->ReLU->全連接->Softmax”。由於網絡實在太簡單了,寫起來完全不需要過多的思考。可能你是這麼寫的(例子出自TF官網:http://tensorflow.org/tutorials/mnist/pros/index.html):

def weight_variable(shape):
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))
def bias_variable(shape):
    return tf.Variable(tf.constant(0.1, shape=shape))

W_conv1 = weight_variable([5, 5, 3, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(tf.nn.conv2d(...))
h_pool1 = tf.nn.max_pool(h_conv1,...)

W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(tf.nn.conv2d(...))
h_pool2 = tf.nn.max_pool(h_conv2,...)

W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_flat = tf.reshape(...)
h_fc1 = tf.nn.relu(tf.matmul(...))

從中應該可以看出來,如果需要添加捲積層或全連接層,需要額外定義相應的權重w和偏置b。因此就有了[W_conv1,b_conv1,W_conv2,b_conv2,…]這一串變量信息。
那麼問題來了,如果讓你寫一個19層的VGG網絡,甚至是上百層的Resnet呢?這種定義方法顯然是行不通的,等手動把[W_conv1,b_conv1,W_conv2,b_conv2,…]這些東西輸完,估計也對TF喪失興趣了。你可能會想到這種循環的方法:

def layer(shape, ...):
    w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
    b = tf.Variable(tf.constant(0.1, shape=shape))
    return tf.nn.relu(tf.nn.conv2d(...))

for i in range(19):
    ...
    x = layer(shape, ...)
    ...

這樣的確就不用一個個寫[w1,w2,w3,w4,….]這些變量了,從某種程度上來看確實解放了雙手。但是,如果我現在想讀取第8個卷積層中w和b的數值,有沒有什麼簡單的方法呢?再或者我想把這個網絡中的參數轉移到另一個完全相同的網絡中使用呢?雖然你可以再定義一組列表var,在每次新定義變量後var.append(w),但從管理變量和傳輸變量的角度來看依舊不是很方便。

萬幸的是,TF早就想到了這一點,並且提供Variable Scope機制來幫助管理變量。有了這個工具,就再也不用爲變量的定義和共享傷腦筋了。

常用函數

  • tf.get_variable():和tf.Variable類似,該函數也是爲了創建一個變量。參數有:

    • name:變量名稱
    • initializer:初始化值
    • trainable:是否可訓練
  • tf.variable_scope():創建一個變量域,相當於在變量空間下打開一個文件夾。一般和tf.get_variable()組合使用,一種常用的用法如下:

import tensorflow as tf

with tf.variable_scope('cnn'):
    with tf.variable_scope('conv1'):
        w = tf.get_variable(
            initializer = tf.truncated_normal([3,3,3,32], stddev=0.1), 
            trainable=True, name = 'w')
        b = tf.get_variable(
            initializer = tf.zeros([32]), 
            trainable=True, name = 'b')

print(w.name)
print(b.name)

結果爲

cnn/conv1/w:0
cnn/conv1/b:0

實例

本章以簡單的MNIST識別爲例,來看看tf.get_variable()和tf.variable_scope()在訓練時能帶給大家怎樣的方便。

模型文件

首先,創建一個py文件,專門存放生成模型的代碼,叫做“cnnmodel.py”。其中定義一下權重和偏置的初始化函數:

import tensorflow as tf

def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.get_variable(initializer = initial, trainable=True, name = 'w')

def bias_variable(shape):
    initial = tf.zeros(shape)
    return tf.get_variable(initializer = initial, trainable=True, name = 'b')

接着,進一步定義卷積層、全連接層等操作,這樣可以省去很多重複的字符:

def conv2d(x, W_shape):
    W = weight_variable(W_shape)
    B = bias_variable(W_shape[-1])
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') + B

def ann(x, W_shape):
    W = weight_variable(W_shape)
    B = bias_variable(W_shape[-1])
    return tf.matmul(x, W) + B

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

個人還是比較喜歡在函數裏面定義變量。每次調用函數,創建的都是局部變量,即使同名也不會有衝突。
然後就是創建CNN模型了:

def cnnmodel(inp, keep_prob):
    with tf.variable_scope('cnn'):

        with tf.variable_scope('conv1'):
            h_conv1 = tf.nn.relu(conv2d(inp, [5, 5, 1, 32]))
            h_pool1 = max_pool_2x2(h_conv1)

        with tf.variable_scope('conv2'):
            h_conv2 = tf.nn.relu(conv2d(h_pool1, [5, 5, 32, 64]))
            h_pool2 = max_pool_2x2(h_conv2)

        with tf.variable_scope('fc1'):
            h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
            h_fc1 = tf.nn.relu(ann(h_pool2_flat, [7 * 7 * 64, 1024]))
            h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

        with tf.variable_scope('fc2'):
            y_conv = tf.nn.softmax(ann(h_fc1_drop, [1024, 10]))

    var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='cnn')
    return y_conv, var

這樣寫的結果就是:在變量空間中有一個總文件夾叫“cnn”,下面有許多子文件夾“conv1”、“conv2”、“fc1”、“fc2”。每個子文件夾下都有“w”、“b”兩個變量。最後的tf.get_collection()就是爲了把在“cnn”目錄下的變量集合起來,一步到位。是不是比自己定義列表一個一個.append()方便多了。

訓練文件

此文件基本照搬Tensorflow官方教程的文檔,變動不大,唯一有區別的就是在調用cnnmodel時額外輸出了網絡變量。

import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
import tensorflow as tf
from cnnmodel import cnnmodel
sess = tf.InteractiveSession()

x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
keep_prob = tf.placeholder("float")
x_image = tf.reshape(x, [-1,28,28,1])

y_conv, var = cnnmodel(x_image, keep_prob)

cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy, var_list = var)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.global_variables_initializer())

for i in range(2000):
    batch = mnist.train.next_batch(50)
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
    if i % 500 == 0 or i == 1999:
        train_accuracy = accuracy.eval(feed_dict={
            x:batch[0], y_: batch[1], keep_prob: 1.0})
        print("step %d, training accuracy %g"%(i, train_accuracy))
        saver = tf.train.Saver()
        saver.save(sess, 'backup/latest')

導入數據

無更改,直接調用即可

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

重新訓練 / 測試文件

在訓練文件中,最後幾行saver.save()把當前訓練的模型導出到backup文件夾下了。該文件可以導入最後訓練的網絡參數,方便做一些測試。可以用這種方法導入全局的參數:

if tf.train.get_checkpoint_state('backup/'):
    print('\nfound\n')
    saver = tf.train.Saver()
    saver.restore(sess, 'backup/latest')

如果你只想導入某一層參數的話,之前的變量管理就幫上忙了:

var_ = tf.global_variables()
net_var = [var for var in var_ if "conv1" in var.name]
if tf.train.get_checkpoint_state('backup/'):
    print('\nfound\n')
    saver = tf.train.Saver(net_var)
    saver.restore(sess, 'backup/latest')

如果不信,可以用print(sess.run(net_var[1]))這種方法打印出參數值,來看看導入的參數和訓練文件中導出的是不是一樣。
要提醒一點,導入參數一定要放在sess.run(tf.global_variables_initializer())之後,否則你剛把變量值導好,一個變量初始化過來又變成預設好的初始值了。
完整文件如下:

import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
import tensorflow as tf
from cnnmodel import cnnmodel
sess = tf.InteractiveSession()

x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
keep_prob = tf.placeholder("float")
x_image = tf.reshape(x, [-1,28,28,1])

y_conv, var = cnnmodel(x_image, keep_prob)

var_ = tf.global_variables()
net_var = [var for var in var_ if "cnn" in var.name]

sess.run(tf.global_variables_initializer())

if tf.train.get_checkpoint_state('backup/'):
    print('\nfound\n')
    saver = tf.train.Saver()
    saver.restore(sess, 'backup/latest')
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章