Concise Implementation of Linear Regression
xiaoyao · Dive into Deep Learning · tensorflow 2.1.0
As deep learning frameworks have matured, developing deep learning applications has become increasingly convenient. In practice, we can usually implement the same model with far less code than in the previous section. In this section, we show how to use the Keras API recommended by tensorflow 2.1.0 to implement linear regression training more conveniently.
Generating the Dataset
We generate the same dataset as in the previous section, where features holds the training inputs and labels holds the targets.
import tensorflow as tf
num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
features = tf.random.normal(shape=(num_examples, num_inputs), stddev=1)
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels += tf.random.normal(labels.shape, stddev=0.01)
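To make the generation step concrete without requiring TensorFlow, here is an equivalent NumPy sketch of the same synthetic dataset: labels follow y = Xw + b plus small Gaussian noise. (The `rng` seed and variable names are illustrative, not part of the original code.)

```python
# NumPy re-creation of the synthetic dataset above (a sketch, not the
# TensorFlow code itself), useful for checking shapes and the noise level.
import numpy as np

num_inputs, num_examples = 2, 1000
true_w, true_b = np.array([2.0, -3.4]), 4.2

rng = np.random.default_rng(0)
features = rng.normal(size=(num_examples, num_inputs))   # X ~ N(0, 1)
labels = features @ true_w + true_b                      # y = Xw + b
labels += rng.normal(scale=0.01, size=labels.shape)      # Gaussian noise, std 0.01

print(features.shape, labels.shape)   # (1000, 2) (1000,)
```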
Reading the Data
Although tensorflow 2.1.0 can fit linear regression on the full dataset directly, with no need to split it into minibatches ourselves, it is still worth learning how to read data with the tf.data API.
# shuffle's buffer_size should be at least the number of examples; batch sets the minibatch size
from tensorflow import data as tfdata
batch_size = 10
# Combine the features and labels of the training data
dataset = tfdata.Dataset.from_tensor_slices((features, labels))
# Read random minibatches
dataset = dataset.shuffle(buffer_size=num_examples)
dataset = dataset.batch(batch_size)
data_iter = iter(dataset)
# Note: an iterator obtained via iter(dataset) can traverse the dataset only once
for X, y in data_iter:
    print(X, y)
    break
tf.Tensor(
[[-1.825709 -0.41736308]
[-0.6260657 -0.37043497]
[-1.2240889 -0.4179984 ]
[-2.252687 0.32804537]
[ 0.00852243 -0.11625145]
[ 0.42531878 -0.9496812 ]
[ 0.4022167 -0.07259909]
[-1.0691589 -0.18955724]
[-0.20947874 1.566279 ]
[ 1.7726566 1.5784163 ]], shape=(10, 2), dtype=float32) tf.Tensor(
[ 1.9595385 4.213116 3.1751869 -1.417277 4.620007 8.291456
5.2536917 2.709371 -1.5466335 2.3810005], shape=(10,), dtype=float32)
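The shuffle-then-batch pipeline above can be mimicked by a short NumPy generator, much like the from-scratch data reader of the previous section. This is a sketch of the behavior, not the tf.data implementation; `data_iter` and the toy arrays are illustrative names.

```python
# Plain-NumPy sketch of what Dataset.shuffle(...).batch(...) does:
# shuffle example indices once per pass, then yield fixed-size minibatches
# (the last batch may be smaller).
import numpy as np

def data_iter(batch_size, features, labels, rng=np.random.default_rng(0)):
    indices = np.arange(len(features))
    rng.shuffle(indices)                          # random read order
    for i in range(0, len(features), batch_size):
        batch = indices[i:i + batch_size]
        yield features[batch], labels[batch]

X = np.arange(20, dtype=float).reshape(10, 2)
y = np.arange(10, dtype=float)
batches = list(data_iter(4, X, y))
print(len(batches))   # 3 batches: sizes 4, 4, 2
```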
Defining the Model
tensorflow 2.x recommends defining networks with Keras, so we do that here. We first define a model variable model, which is a Sequential instance. In Keras, a Sequential instance can be viewed as a container that chains layers together: when constructing the model, we add layers to this container one by one, and given input data, each layer computes its output in turn and passes it to the next layer as input. Conveniently, in Keras we do not need to specify the input shape of each layer.
Since linear regression fully connects the input layer to the output layer, a single fully connected layer, keras.layers.Dense(), suffices.
In Keras, the kernel_initializer and bias_initializer arguments set how the weights and bias are initialized. Here we import the initializers module from tensorflow and pass RandomNormal(stddev=0.01), so that each weight element is sampled at initialization from a normal distribution with mean 0 and standard deviation 0.01. The bias is initialized to zero by default.
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow import initializers as init
model = keras.Sequential()  # a container that chains layers together
model.add(layers.Dense(1, kernel_initializer=init.RandomNormal(stddev=0.01)))
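Once built, a Dense(1) layer computes an affine map: output = X @ kernel + bias, with a kernel of shape (num_inputs, 1) and a bias of shape (1,). A minimal NumPy sketch of that forward pass (`dense_forward` is an illustrative name, not a Keras function):

```python
# What layers.Dense(1) computes once built: output = X @ kernel + bias.
import numpy as np

def dense_forward(X, kernel, bias):
    return X @ kernel + bias

rng = np.random.default_rng(0)
kernel = rng.normal(scale=0.01, size=(2, 1))   # mimics RandomNormal(stddev=0.01)
bias = np.zeros(1)                             # zeros, the Keras default
X = rng.normal(size=(10, 2))
print(dense_forward(X, kernel, bias).shape)    # (10, 1)
```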
Defining the Loss Function and Optimization Algorithm
We use the squared loss (mse) as the loss function and stochastic gradient descent (sgd) as the optimizer.
In Keras, after defining a model you can call its compile() method to configure its loss function and optimization method.
To define the loss, you only need to pass the loss argument: Keras provides a range of loss functions, and here we directly use its squared loss mse as the model's loss function.
Nor do we need to implement minibatch stochastic gradient descent ourselves; passing the optimizer argument is enough. Keras provides a range of optimization algorithms, and here we directly specify minibatch stochastic gradient descent with a learning rate of 0.03, tf.keras.optimizers.SGD(0.03), as the optimization algorithm.
from tensorflow import losses
loss = losses.MeanSquaredError()
from tensorflow.keras import optimizers
trainer = optimizers.SGD(learning_rate=0.03)
loss_history = []
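The update that optimizers.SGD applies to each trainable variable is simply param ← param − lr · grad. A minimal NumPy sketch of one such step (`sgd_step`, `grad_w`, and `grad_b` are illustrative names, not part of the Keras API):

```python
# The update optimizers.SGD applies to each variable: param -= lr * grad.
import numpy as np

def sgd_step(params, grads, lr=0.03):
    return [p - lr * g for p, g in zip(params, grads)]

w, b = np.array([0.0, 0.0]), 0.0
grad_w, grad_b = np.array([1.0, -2.0]), 0.5
w, b = sgd_step([w, b], [grad_w, grad_b])
print(w, b)   # [-0.03  0.06] -0.015
```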
When training a model with Keras, we could call the model instance's fit function to iterate over the data: fit only needs the inputs x and outputs y, the number of passes over the data (epochs), and the minibatch size batch_size; here that would be epochs=3 and batch_size=10. With fit, there is no need to batch the dataset ourselves at all.
Instead, when training with low-level tensorflow, we record the computation with tf.GradientTape and call tape.gradient to obtain the gradient of each variable in the recorded computation. We then look up the variables to update via model.trainable_variables and apply the gradients with trainer.apply_gradients, completing one training step.
num_epochs = 3
for epoch in range(1, num_epochs + 1):
    for (batch, (X, y)) in enumerate(dataset):
        with tf.GradientTape() as tape:
            l = loss(model(X, training=True), y)
        loss_history.append(l.numpy().mean())
        grads = tape.gradient(l, model.trainable_variables)
        trainer.apply_gradients(zip(grads, model.trainable_variables))
    l = loss(model(features), labels)
    print('epoch %d, loss: %f' % (epoch, l))
epoch 1, loss: 0.000264
epoch 2, loss: 0.000097
epoch 3, loss: 0.000097
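What the GradientTape loop computes can also be written out by hand. For ŷ = Xw + b, the MSE gradients are dw = (2/n)·Xᵀ(ŷ − y) and db = (2/n)·Σ(ŷ − y), and SGD subtracts lr times each. A self-contained NumPy sketch (full-batch steps and noise-free labels for simplicity; not the TensorFlow code above):

```python
# Hand-rolled analogue of tape.gradient + apply_gradients for linear
# regression: analytic MSE gradients, then SGD updates.
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
y = X @ np.array([2.0, -3.4]) + 4.2          # noise-free for simplicity

w, b, lr = np.zeros(2), 0.0, 0.03
for _ in range(200):                          # full-batch steps
    err = X @ w + b - y                       # residuals
    w -= lr * 2 * X.T @ err / len(y)          # analogue of tape.gradient for w
    b -= lr * 2 * err.mean()                  # ...and for the bias
print(np.round(w, 2), round(b, 2))            # recovers roughly [2, -3.4] and 4.2
```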
Below we compare the learned model parameters with the true parameters. We can obtain the model's weight (weight) and bias (bias) via its get_weights() method. The learned parameters are close to the true ones.
true_w, model.get_weights()[0]
([2, -3.4],
array([[ 1.9998281],
[-3.3996763]], dtype=float32))
true_b, model.get_weights()[1]
(4.2, array([4.1998463], dtype=float32))
loss_history
[29.501896,
 37.45631,
 12.249255,
 23.72889,
 23.883945,
 ...
 0.00014087862,
 0.00011540208]
(300 recorded minibatch losses, one per batch over 3 epochs, decaying steadily from ~30 to ~1e-4; the middle entries are omitted here.)
tensorflow lets us implement models concisely: the tensorflow.data module provides data-processing utilities, the tensorflow.keras.layers module defines a large number of neural network layers, the tensorflow.initializers module defines various initialization methods, and the tensorflow.optimizers module provides various optimization algorithms for models.