tensorflow2 中关于自定义层的build() 和 call()一点探究

0x00 先上一段代码

问题: 在自定义网络层的时候,想要搞清楚build()call() 是用来做什么的,为什么能调用成功,不用外部再定义


# coding=utf-8
'''
@ Summary: test call
@ Update:  

@ file:    test.py
@ version: 1.0.0

@ Author:  [email protected]
@ Date:    2020/6/11 下午3:48
'''
from __future__ import absolute_import, division, print_function
import tensorflow as tf
tf.keras.backend.clear_session()
import tensorflow.keras as keras
import tensorflow.keras.layers as layers

class MyLayer(layers.Layer):
   def __init__(self, unit=32):
       super(MyLayer, self).__init__()
       self.unit = unit

   def build(self, input_shape):
       self.weight = self.add_weight(shape=(input_shape[-1], self.unit),
                                     initializer=keras.initializers.RandomNormal(),
                                     trainable=True)
       self.bias = self.add_weight(shape=(self.unit,),
                                   initializer=keras.initializers.Zeros(),
                                   trainable=True)

   def call(self, inputs):
       return tf.matmul(inputs, self.weight) + self.bias

my_layer = MyLayer(3)
x = tf.ones((3,5))
out = my_layer(x)
print(out)

0x01 庖丁解牛1 - init

定义一个类对象


my_layer = MyLayer(3)

ok, 此处没有任何问题

上面是仅调用了MyLayer() 类中的__init__() 方法,获得了self.units = 3 这一个变量

此处尚未调用类中的 build()call() 方法



   def __init__(self, unit=32):

       # 继承,此处不多说,有一个很有意思的是单继承和多继承
       super(MyLayer, self).__init__() 
       self.unit = unit

0x02 庖丁解牛2 – build()

初始化一个输入对象:


x = tf.ones((3,5))

这一步也是没有任何问题,继续往下


out = my_layer(x)

这个地方,问题就来了。

回到最开始的问题:为什么不用外部调用就可以运行build()call()等函数?

回答:在Layer() 类中有一个__call__() 魔法方法(上述两个函数已经被tf集成在该函数下面),会被自动调用,因此不用外部调用,具体怎么个调用过程,请阅读源码

接下来就是对my_layer 输入,输入为x

调用build() 方法:


   def build(self, input_shape):
       self.weight = self.add_weight(shape=(input_shape[-1], self.unit),
                                     initializer=keras.initializers.RandomNormal(),
                                     trainable=True)
       self.bias = self.add_weight(shape=(self.unit,),
                                   initializer=keras.initializers.Zeros(),
                                   trainable=True)

初始化两个可训练的值,分别是权重和偏置,ok,此部分问题解决了

顺带解决另外一个问题:为什么要有build() 方法

回答: build() 可自定义网络的权重的维度,可以根据输入来指定权重的维度,若权重固定,可避免使用build() 方法

另外一个需要注意的地方在于:self.built = True

该参数在build() 运行开始时为False,为了保证先调用build() 方法, 再调用call() 方法

结束时会自动赋值为True,保证build() 方法只被调用一次


class MyLayer(layers.Layer):

    def __init__(self, input_dim=32, unit=32):

        super(MyLayer, self).__init__()

        self.weight = self.add_weight(shape=(input_dim, unit),

                                     initializer=keras.initializers.RandomNormal(),

                                     trainable=True)

        self.bias = self.add_weight(shape=(unit,),

                                   initializer=keras.initializers.Zeros(),

                                   trainable=True)

    

    def call(self, inputs):

        return tf.matmul(inputs, self.weight) + self.bias

0x03 庖丁解牛3 – call()

在调用完了build() 方法之后,获取到了初始化的权重和偏置值,接下来进行正向传播,官网上是说实现逻辑功能函数,我更喜欢说成是前者,更好理解


   def call(self, inputs):
       return tf.matmul(inputs, self.weight) + self.bias

返回该层的输出值,不包含激活函数计算

0x04 最终输出


print(out)

0x05 总结

Layer的子类一般实现如下:

  • init():super(), 初始化所有与输入无关的变量

  • build():用于初始化层内的参数和变量

  • call():定义前向传播

第一次训练先计算Model(x), 然后计算Model(x).build(input),最后计算Model(x).call(input),第二次往后就跳过了中间步骤

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章