Wide & Deep 模型
稀疏特徵：離散值特徵（可做叉乘，即特徵交叉）
密集特徵：以稠密向量表達
示例
函數式 API
# Functional-API Wide & Deep model.
# NOTE(review): assumes `x_train` (defined earlier) is a NumPy-like array whose
# first axis is the sample axis; the remaining axes give the per-sample shape.
input_layer = keras.layers.Input(shape=x_train.shape[1:])

# Deep path: two 30-unit ReLU Dense layers stacked on the input.
hidden1 = keras.layers.Dense(30, activation='relu')(input_layer)
hidden2 = keras.layers.Dense(30, activation='relu')(hidden1)

# Wide path: the raw input is passed straight through (no layers).
# Join wide (raw input) and deep (hidden2) representations.
concat = keras.layers.concatenate([input_layer, hidden2])

# Output: a single linear unit.
output = keras.layers.Dense(1)(concat)

# Wrap the graph into a trainable Model. It is bound to `model`; the original
# code bound it to `concat`, overwriting the concatenation tensor's name.
model = keras.models.Model(inputs=[input_layer],
                           outputs=[output])
子類 API（繼承 keras.models.Model）
class WideDeepModel(keras.models.Model):
    """Wide & Deep model written with the Keras subclassing API.

    The deep path is two 30-unit ReLU Dense layers. The wide path is the raw
    input itself. The two are concatenated and fed to a 1-unit Dense output
    layer.
    """

    def __init__(self, **kwargs):
        """Define the model's layers.

        Args:
            **kwargs: Forwarded unchanged to ``keras.models.Model``
                (for example ``name``). Calling with no arguments behaves as
                before.
        """
        super().__init__(**kwargs)
        self.hidden1_layer = keras.layers.Dense(30, activation="relu")
        self.hidden2_layer = keras.layers.Dense(30, activation="relu")
        self.output_layer = keras.layers.Dense(1)

    def call(self, input):
        """Run the forward pass.

        Args:
            input: The model input. It is used both directly (wide path) and
                as the input to the hidden layers (deep path).

        Returns:
            The output layer applied to the concatenation of ``input`` and the
            second hidden layer's activations.
        """
        hidden1 = self.hidden1_layer(input)
        hidden2 = self.hidden2_layer(hidden1)
        # Wide + deep: join the untouched input with the deep features.
        concat = keras.layers.concatenate([input, hidden2])
        return self.output_layer(concat)
# Two equivalent ways to obtain the model: use `WideDeepModel()` directly, or
# wrap it in a Sequential container as done here.
model = keras.models.Sequential(layers=[WideDeepModel()])
# input_shape = (number of samples, number of input features); None leaves the
# sample/batch dimension unspecified.
model.build((None, 8))
# Building only creates the layers' weights; compile() and fit() are still
# required before training.
多輸入
# Multi-input Wide & Deep: the wide path takes a 5-feature input directly, and
# the deep path takes a separate 6-feature input through two hidden layers.
input_wide = keras.layers.Input(shape=(5,))
input_deep = keras.layers.Input(shape=(6,))

# Deep path.
hidden1 = keras.layers.Dense(units=30, activation="relu")(input_deep)
hidden2 = keras.layers.Dense(units=30, activation="relu")(hidden1)

# Merge the wide input with the deep features, then map to one output unit.
concat = keras.layers.concatenate([input_wide, hidden2])
output = keras.layers.Dense(units=1)(concat)

model = keras.models.Model(
    inputs=[input_wide, input_deep],
    outputs=[output],
)
model.summary()
model.compile(optimizer="sgd", loss="mean_squared_error")

# NOTE(review): `callbacks` is not used in the visible code; presumably it is
# passed to model.fit() later.
callbacks = [
    keras.callbacks.EarlyStopping(min_delta=1e-2, patience=5),
]
擴展：多輸出
# Multi-output Wide & Deep. Besides the main output computed from the
# concatenated wide + deep features, a second output is taken from the deep
# path alone.
input_wide = keras.layers.Input(shape=(5,))
input_deep = keras.layers.Input(shape=(6,))

# Deep path.
hidden1 = keras.layers.Dense(units=30, activation="relu")(input_deep)
hidden2 = keras.layers.Dense(units=30, activation="relu")(hidden1)

# Main head: wide input + deep features. Second head: deep features only.
concat = keras.layers.concatenate([input_wide, hidden2])
output = keras.layers.Dense(units=1)(concat)
output2 = keras.layers.Dense(units=1)(hidden2)

model = keras.models.Model(
    inputs=[input_wide, input_deep],
    outputs=[output, output2],
)
model.summary()
# A single loss string is used for every output.
model.compile(optimizer="sgd", loss="mean_squared_error")

# NOTE(review): `callbacks` is not used in the visible code; presumably it is
# passed to model.fit() later.
callbacks = [
    keras.callbacks.EarlyStopping(min_delta=1e-2, patience=5),
]