前文寫了如何使用tensorflow2.0自定義Layer,本文將講述如何自定義Model,並將前述的Layer應用到本Model中來。
(一)tensorflow2.0 - 自定義layer
(二)tensorflow2.0 - 自定義Model
(三)tensorflow2.0 - 自定義loss function(損失函數)
(四)tensorflow2.0 - 實戰稀疏自動編碼器SAE
自定義模型也比較簡單,只是需要搞清楚Model中各部分的作用及執行流程即可。
由於本例中將使用前文中的自定義Layer,因此先將其代碼貼過來以便查閱,沒看過前文的也沒關係,不影響對自定義模型的理解。
import tensorflow as tf
from tensorflow.keras import *
class SAELayer(layers.Layer):
# 初始化num_outputs,即當前層輸出元素的個數
def __init__(self, num_outputs):
super(SAELayer, self).__init__()
self.num_outputs = num_outputs
# 在第一次調用該Layer的call方法前(自動)調用該函數,可以知道輸入數據的shape
# 根據輸入數據的shape可以初始化權值、bias的矩陣
def build(self, input_shape):
self.kernel = self.add_variable("kernel",
shape=[int(input_shape[-1]),
self.num_outputs])
self.bias = self.add_variable("bias",
shape=[self.num_outputs])
def call(self, input):
output = tf.matmul(input, self.kernel) + self.bias
# sigmoid激活函數
output = tf.nn.sigmoid(output)
return output
下面自定義模型了,引入的庫函數見上面代碼的最前面。需要注意,Layer和Model都是類,且都要繼承自某些父類。這裏繼承的是tensorflow.keras.Model
。這裏需要實現兩個方法,即__init__()
和call()
。__init__()
是在創建類的對象時調用的,可以按需傳入一些初始化參數。下例構建的是一個三層模型(輸入層由於體現不出來,所以代碼裏看起來是兩層)。
class SAEModel(Model):
def __init__(self, input_shape, output_shape, hidden_shape=None):
# print("init")
# 隱藏層節點個數默認爲輸入層的3倍
if hidden_shape == None:
hidden_shape = 3 * input_shape
# 調用父類__init__()方法
super(SAEModel, self).__init__()
# 初始化模型使用的layer,layer_1爲前述自定義layer
self.layer_1 = SAELayer(hidden_shape)
# layer_2爲全連接層,採用sigmoid激活函數
# 每層在這裏可以不考慮輸入元素個數,但必須考慮輸出元素個數
# 輸入元素個數可以在call()函數中動態確定
self.layer_2 = layers.Dense(output_shape, activation=tf.nn.sigmoid)
def call(self, input_tensor, training=False):
# 輸入數據
hidden = self.layer_1(input_tensor)
output = self.layer_2(hidden)
return output
到此模型就定義完了,然後可以按照一般的流程使用該模型。
下面只是簡單的使用模型的例子,只羅列出來,參數沒有完善,請按需補充後使用。
input_shape = 5
output_shape = 6
model = SAEModel(input_shape, output_shape)
model.build(input_shape=[None, 5])
model.summary()
model.compile(optimizer=, loss=, metrics=[])
到此自定義Model已經結束了,但是很多時候我們往往需要自定義損失函數,而如果損失函數需要自定義除了預測值和實際值之外的額外參數的話,還需要對model進行修改,這我們將在下一篇文章中討論。
(三)tensorflow2.0 - 自定義loss function(損失函數)
參考文獻: