Convolution in Deep Learning: What Happens When the Kernel Is Initialized to Zero

Preface

A few days ago an interviewer asked me the following question: what happens if the weights of a convolutional layer are all set to zero?

To answer it, let us first pin down how weights and biases are typically initialized in a neural network. In TensorFlow, the weights are usually initialized by the user, who can choose among many initializers such as glorot_normal_initializer(). The bias, however, is initialized to zero by default; see tf.layers.conv2d and tf.layers.dense, both of which default the bias initializer to zeros. We follow that convention here and initialize both the convolution kernel and the bias to zero.

Analysis

This initialization is quite peculiar. Once all of a layer's parameters are zero, that layer's output is necessarily zero. In the next layer, whatever the distribution of its kernel, convolving an all-zero input still yields zero, and since its bias defaults to zero, its output is zero as well. The output of the entire network is therefore zero.
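
The forward-pass claim is easy to check numerically. Below is a minimal NumPy sketch (a naive single-channel 'valid' convolution written only for illustration): a layer with a zero kernel and zero bias outputs zeros for any input, and a following layer with a random kernel but zero bias maps those zeros to zeros again.

```python
import numpy as np

def conv_valid(x, k, b):
    # naive single-channel 'valid' cross-correlation, for illustration only
    kh, kw = k.shape
    out = np.zeros((x.shape[0] - kh + 1, x.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k) + b
    return out

x = np.arange(16.0).reshape(4, 4)                         # arbitrary non-zero input
y = conv_valid(x, np.zeros((3, 3)), 0.0)                  # zero kernel, zero bias
z = conv_valid(np.pad(y, 1), np.random.randn(3, 3), 0.0)  # next layer: random kernel, zero bias
print((y == 0).all(), (z == 0).all())                     # True True
```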

The parameters are updated during back-propagation, and computing both the parameter gradients and the error passed back to earlier layers requires each layer's input. Since the inputs of most layers here are all zeros, the gradients come out as zero and no update can take place.

Parameter Gradient Updates

Consider the kernel and bias parameters of some layer.

If the layer sits before the zero-initialized one, then from the viewpoint of automatic differentiation, a small perturbation of any of its parameters is wiped out by the subsequent convolution with a zero kernel. The final loss therefore does not change at all, which is to say the gradient of that parameter is zero and the parameter cannot be updated.
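
A finite-difference check makes this concrete: perturb a weight in a layer that feeds a zero-weighted layer, and the loss does not move. The two-layer setup below is hypothetical, just enough to show the effect.

```python
import numpy as np

x = np.array([1.0, 2.0])            # input (made up for the sketch)
w_up = np.array([0.5, -0.3])        # weights of the layer BEFORE the zero layer
w_zero = 0.0                        # the zero-initialized downstream weight

def loss(w):
    h = np.maximum(w @ x, 0.0)      # upstream layer with ReLU
    return (w_zero * h) ** 2        # zero layer, then a squared loss

eps = 1e-6
g = (loss(w_up + np.array([eps, 0.0])) - loss(w_up)) / eps
print(g)                            # 0.0: the perturbation never reaches the loss
```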

If the layer sits after the zero-initialized one, again from the viewpoint of automatic differentiation, a small perturbation of a kernel parameter changes nothing in the final loss, because this layer's input is zero; the kernel's gradient is therefore zero and the kernel cannot be updated. As for the bias, it contributes nothing to the loss at the start, and since back-propagation through this layer multiplies by the layer's (all-zero) input, the error matrix passed on to earlier layers is all zero, so no gradient update can happen there either. (The one exception is the bias of the very last layer, which feeds the loss directly and does receive a non-zero gradient, as the results below show.)
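
The standard back-propagation formulas for a dense layer make both effects explicit: the weight gradient has the layer's input as a factor and vanishes when that input is zero, while the bias gradient depends only on the error arriving from above. The shapes below are arbitrary.

```python
import numpy as np

x = np.zeros((4, 3))             # this layer's input: all zeros
delta = np.random.randn(4, 2)    # error arriving from the layer above

grad_W = x.T @ delta             # dL/dW = x^T @ delta -> identically zero
grad_b = delta.sum(axis=0)       # dL/db = sum(delta)  -> need not be zero

print((grad_W == 0).all())       # True
```

In the network in this post, however, the delta reaching the hidden layers is itself zeroed out by the ReLU gates (all pre-activations are zero), so their biases do not move either; only the last layer's bias receives the raw loss gradient.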

Code

The code trains on one fixed batch of samples, so that this run can be compared with runs using other initialization schemes while everything else is held constant (a controlled experiment). In fact, the outcome is the same no matter what input the model receives.

# coding=utf-8
# python3

import tensorflow as tf
from tensorflow.keras.datasets import mnist

import numpy as np


class DataPipeline():
    def __init__(self):
        (image_train, label_train), (image_test, label_test) = mnist.load_data()
        self.image_train = np.array(image_train)
        self.image_test = np.array(image_test)

        self.label_train = np.array(label_train)
        self.label_test = np.array(label_test)

    def next(self, n=1, tag="train"):
        if tag == 'train':
            length = len(self.image_train)
            index = np.random.choice(np.arange(length), n)
            images = self.image_train[index]
            labels = self.label_train[index]
            return np.reshape(images, [n, -1]), labels
        if tag == 'test':
            length = len(self.image_test)
            index = np.random.choice(np.arange(length), n)
            images = self.image_test[index]
            labels = self.label_test[index]
            return np.reshape(images, [n, -1]), labels

    def fixed(self, n=50, tag='train'):
        if tag == 'train':
            images = self.image_train[:n]
            labels = self.label_train[:n]

            return np.reshape(images, [n, -1]), labels
        if tag == 'test':
            images = self.image_test[:n]
            labels = self.label_test[:n]

            return np.reshape(images, [n, -1]), labels


def conv(x):
    tf.set_random_seed(1)

    x = tf.layers.conv2d(x, 32, 3, 1, 'SAME',
                         activation=tf.nn.relu,
                         use_bias=True,
                         kernel_initializer=tf.glorot_normal_initializer(),
                         bias_initializer=tf.zeros_initializer(),
                         name='conv1',
                         reuse=tf.AUTO_REUSE)

    x = tf.layers.max_pooling2d(x, 2, 2, 'VALID')

    ##############################
# This is a baseline where every variable is randomly initialized
    ##############################

    # x = tf.layers.conv2d(x, 32, 3, 1, 'SAME',
    #                      activation=tf.nn.relu,
    #                      use_bias=True,
    #                      kernel_initializer=tf.glorot_normal_initializer(),
    #                      bias_initializer=tf.zeros_initializer(),
    #                      name='conv2',
    #                      reuse=tf.AUTO_REUSE)

    ##############################
# This is a modification where the kernel and bias are initialized with 0
    ##############################

    x = tf.layers.conv2d(x, 32, 3, 1, 'SAME',
                         activation=tf.nn.relu,
                         use_bias=True,
                         kernel_initializer=tf.zeros_initializer(),
                         bias_initializer=tf.zeros_initializer(),
                         name='conv2',
                         reuse=tf.AUTO_REUSE)

    x = tf.layers.max_pooling2d(x, 2, 2, 'VALID')

    tf.set_random_seed(2)

    x = tf.layers.conv2d(x, 32, 3, 1, 'SAME',
                         activation=tf.nn.relu,
                         use_bias=True,
                         kernel_initializer=tf.glorot_normal_initializer(),
                         bias_initializer=tf.zeros_initializer(),
                         name='conv3')

    x = tf.layers.conv2d(x, 32, 3, 1, 'SAME',
                         activation=tf.nn.relu,
                         use_bias=True,
                         kernel_initializer=tf.glorot_normal_initializer(),
                         bias_initializer=tf.zeros_initializer(),
                         name='conv4')

    return x


def fc(x):
    x = tf.layers.flatten(x)

    x = tf.layers.dense(x, 128, activation=tf.nn.relu,
                        use_bias=True, 
                        kernel_initializer=tf.glorot_normal_initializer(), 
                        bias_initializer=tf.zeros_initializer(),
                        name='fc1')

    x = tf.layers.dense(x, 10, activation=None,
                        use_bias=True, 
                        kernel_initializer=tf.glorot_normal_initializer(), 
                        bias_initializer=tf.zeros_initializer(),
                        name='fc2')

    return x


def main():
    dataset = DataPipeline()

    image = tf.placeholder(tf.float32, shape=[None, 784])
    label = tf.placeholder(tf.int64, shape=[None])

    x = tf.reshape(image, [-1, 28, 28, 1])

    x = conv(x)

    x = fc(x)

    names = [i.name for i in tf.global_variables()]

    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=x))

    train_step = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)

    correct_prediction = tf.equal(tf.argmax(x, 1), label)

    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    sess = tf.Session()

    sess.run(tf.global_variables_initializer())

    batch = dataset.next(1000, 'test')
    print("test accuracy %g" % sess.run(accuracy, feed_dict={image: batch[0], label: batch[1]}))

    #####################
    # Extract all the variables
    #####################

    vars = {}

    for i in names:
        vars[i] = sess.run(tf.get_default_graph().get_tensor_by_name(i))

    batch = dataset.fixed(50)
    for i in range(1000):

        if i % 100 == 0:
            train_accuracy = sess.run(accuracy, feed_dict={image: batch[0], label: batch[1]})
            print("step %d,train accuracy %g" % (i, train_accuracy))

        sess.run(train_step, feed_dict={image: batch[0], label: batch[1]})

    batch = dataset.fixed(1000, 'test')
    print("test accuracy %g" % sess.run(accuracy, feed_dict={image: batch[0], label: batch[1]}))

    #####################
    # Extract the variables again
    #####################

    for i in names:
        vars[i] -= sess.run(tf.get_default_graph().get_tensor_by_name(i))

    for i in names:
        print(i)
        print(vars[i])


if __name__ == '__main__':
    main()

Results
test accuracy 0.11
step 0,train accuracy 0.08
step 100,train accuracy 0.14
step 200,train accuracy 0.14
step 300,train accuracy 0.14
step 400,train accuracy 0.14
step 500,train accuracy 0.14
step 600,train accuracy 0.14
step 700,train accuracy 0.14
step 800,train accuracy 0.14
step 900,train accuracy 0.14
test accuracy 0.126
(Printed before/after differences, abbreviated: an all-zero array means the parameter never changed during training.)
conv1/kernel:0   all zeros
conv1/bias:0     all zeros
conv2/kernel:0   all zeros
conv2/bias:0     all zeros
conv3/kernel:0   all zeros
conv3/bias:0     all zeros
conv4/kernel:0   all zeros
conv4/bias:0     all zeros
fc1/kernel:0     all zeros
fc1/bias:0       all zeros
fc2/kernel:0     all zeros
fc2/bias:0
[ 1.9999105e-05 -3.9998205e-05  1.9999105e-05 -3.9998202e-05
  1.9999103e-05  1.9999105e-05  1.4156103e-11  1.9999105e-05
  1.9999105e-05 -3.9998198e-05]

The results show that only the final bias term was updated: it is tied directly to the output and the labels, so it still receives a gradient, while every parameter upstream of it stays frozen. A model like this can improve its fit only marginally and is useless for real work. Therefore, given that biases default to zero, the weights of convolutional and fully connected layers must never be initialized to zero.
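
The conclusion can also be reproduced without TensorFlow. The NumPy sketch below builds a tiny two-layer network whose first layer is zero-initialized, runs one manual back-propagation step through softmax cross-entropy, and checks which gradients are non-zero (the data, shapes, and seed are all made up for the demonstration):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 5))                     # random batch
y = np.eye(3)[rng.integers(0, 3, size=8)]       # one-hot labels

W1, b1 = np.zeros((5, 4)), np.zeros(4)          # zero-initialized layer
W2, b2 = rng.normal(size=(4, 3)), np.zeros(3)   # normally initialized layer

# forward pass
h_pre = x @ W1 + b1                             # all zeros
h = np.maximum(h_pre, 0.0)                      # ReLU -> still all zeros
logits = h @ W2 + b2
p = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# backward pass (softmax cross-entropy)
d_logits = (p - y) / len(x)
grad_W2 = h.T @ d_logits                        # zero: the layer's input h is zero
grad_b2 = d_logits.sum(axis=0)                  # non-zero: the only live gradient
d_pre = (d_logits @ W2.T) * (h_pre > 0)         # ReLU gate is closed -> zero
grad_W1 = x.T @ d_pre                           # zero
grad_b1 = d_pre.sum(axis=0)                     # zero

print((grad_W1 == 0).all(), (grad_b1 == 0).all(),
      (grad_W2 == 0).all(), (grad_b2 != 0).any())   # True True True True
```

Exactly as in the experiment above, every gradient except the final bias is zero, so repeated SGD steps can change nothing but that bias.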
