Concise Implementation of Linear Regression in TensorFlow 2.1.0

xiaoyao · Dive into Deep Learning · TensorFlow 2.1.0

As deep learning frameworks have matured, developing deep learning applications has become increasingly convenient. In practice, we can usually implement the same model with far less code than in the previous section. In this section, we show how to use the Keras API recommended by TensorFlow 2.1.0 to implement linear regression training more conveniently.

Generating the Dataset

We generate the same dataset as in the previous section, where features holds the training examples' features and labels holds the labels.

import tensorflow as tf

num_inputs = 2
num_examples = 1000

# true parameters of the underlying linear model
true_w = [2, -3.4]
true_b = 4.2

# features: 1000 examples drawn from a standard normal distribution
features = tf.random.normal(shape=(num_examples, num_inputs), stddev=1)
# labels: the true linear function of the features, plus Gaussian noise
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels += tf.random.normal(labels.shape, stddev=0.01)
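
As a quick sanity check (a small addition of our own, not in the original), we can print the first example:

# each example is a 2-dimensional feature vector with a scalar label
print(features[0], labels[0])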

Reading the Data

Although for linear regression TensorFlow 2.1.0 can fit the model directly, without us batching the dataset ourselves, we will still walk through how to read data in mini-batches.

from tensorflow import data as tfdata

batch_size = 10
# combine the training features and labels
dataset = tfdata.Dataset.from_tensor_slices((features, labels))
# read mini-batches in random order: shuffle's buffer_size should be at least
# the number of examples, and batch sets the mini-batch size
dataset = dataset.shuffle(buffer_size=num_examples)
dataset = dataset.batch(batch_size)
data_iter = iter(dataset)
# an iterator obtained via iter(dataset) can traverse the dataset only once
for X, y in data_iter:
    print(X, y)
    break
tf.Tensor(
[[-1.825709   -0.41736308]
 [-0.6260657  -0.37043497]
 [-1.2240889  -0.4179984 ]
 [-2.252687    0.32804537]
 [ 0.00852243 -0.11625145]
 [ 0.42531878 -0.9496812 ]
 [ 0.4022167  -0.07259909]
 [-1.0691589  -0.18955724]
 [-0.20947874  1.566279  ]
 [ 1.7726566   1.5784163 ]], shape=(10, 2), dtype=float32) tf.Tensor(
[ 1.9595385  4.213116   3.1751869 -1.417277   4.620007   8.291456
  5.2536917  2.709371  -1.5466335  2.3810005], shape=(10,), dtype=float32)
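
Because iter(dataset) yields a one-shot iterator, repeated passes are made by iterating the Dataset object itself; since shuffle reshuffles by default, each pass also sees a fresh order. A minimal sketch (our own illustration):

# iterating the Dataset directly starts a fresh, re-shuffled pass each time
for epoch in range(2):
    for X, y in dataset:
        pass  # consume one full epoch of mini-batches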

Defining the Model

TensorFlow 2.x recommends defining networks with Keras, so that is what we use here. We first define a model variable model, which is a Sequential instance. In Keras, a Sequential instance can be viewed as a container that chains layers together.

When constructing the model, we add layers to this container one by one. Given input data, each layer in the container computes its output in turn and passes it on as the input to the next layer. Importantly, in Keras we do not need to specify the input shape of each layer. Since this is linear regression, the input layer is fully connected to the output layer, so we define a single fully connected layer, keras.layers.Dense().

In Keras, the kernel_initializer and bias_initializer options set how the weights and bias are initialized, respectively. Here we import the initializers module from tensorflow and use RandomNormal(stddev=0.01) to specify that each element of the weight parameter be randomly sampled at initialization from a normal distribution with mean 0 and standard deviation 0.01. The bias parameter is initialized to zero by default.

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow import initializers as init

model = keras.Sequential()  # a container that chains layers together
model.add(layers.Dense(1, kernel_initializer=init.RandomNormal(stddev=0.01)))
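
Because no input shape was specified, the layer's weights are created only on the first forward pass. The following quick check is our own illustration, not part of the original:

print(model.built)       # False: no weights have been created yet
_ = model(features[:2])  # the first call infers the input shape and builds the layer
print([w.shape for w in model.get_weights()])  # [(2, 1), (1,)]: weight matrix and bias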

Defining the Loss Function and Optimization Algorithm

We define the loss function and the optimizer: the loss is MSE, and the optimizer is SGD (stochastic gradient descent).

In Keras, after defining a model, calling its compile() method configures the model's loss function and optimization method.

To set the loss function we only need to pass the loss argument; Keras provides various loss functions, and we use its squared loss, MSE, directly as the model's loss function.

Nor do we need to implement mini-batch stochastic gradient descent ourselves: we just pass the optimizer argument. Keras provides various optimization algorithms, and here we directly specify mini-batch stochastic gradient descent with a learning rate of 0.03, tf.keras.optimizers.SGD(0.03), as the optimizer.

from tensorflow import losses
from tensorflow.keras import optimizers

loss = losses.MeanSquaredError()              # squared (MSE) loss
trainer = optimizers.SGD(learning_rate=0.03)  # mini-batch SGD with learning rate 0.03
loss_history = []                             # records the per-batch training loss

When training a model with Keras, we can iterate by calling the model instance's fit function. fit only needs the inputs x and targets y, plus epochs (the number of passes over the data) and batch_size (the mini-batch size for each gradient update); here we would use epochs=3 and batch_size=10 (a minimal sketch of this route follows). With Keras we do not even need to batch the dataset ourselves.
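
For completeness, here is a minimal sketch of that compile()/fit() route, assuming the model defined above. The rest of this section trains with an explicit GradientTape loop instead, so this block is illustrative only:

# configure the squared loss and SGD optimizer, then let fit() handle
# batching, shuffling, and epoch iteration
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.03), loss='mse')
model.fit(features, labels, epochs=3, batch_size=10)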

When training with TensorFlow directly, we record gradients on the dynamic graph via tf.GradientTape and call tape.gradient to obtain the gradients of the variables in that graph.

We find the variables to be updated via model.trainable_variables, and use trainer.apply_gradients to update the weights, which completes one training step.

num_epochs = 3
for epoch in range(1, num_epochs + 1):
    for (batch, (X, y)) in enumerate(dataset):
        with tf.GradientTape() as tape:
            # MeanSquaredError expects (y_true, y_pred)
            l = loss(y, model(X, training=True))

        loss_history.append(l.numpy().mean())
        # compute gradients of the loss w.r.t. the trainable variables and apply them
        grads = tape.gradient(l, model.trainable_variables)
        trainer.apply_gradients(zip(grads, model.trainable_variables))

    # evaluate the loss on the full dataset at the end of each epoch
    l = loss(labels, model(features))
    print('epoch %d, loss: %f' % (epoch, l))

epoch 1, loss: 0.000264
epoch 2, loss: 0.000097
epoch 3, loss: 0.000097

Below we compare the learned model parameters with the true ones. We can obtain the model's weight and bias via its get_weights() method. The learned parameters are close to the true parameters.

true_w, model.get_weights()[0]
([2, -3.4],
 array([[ 1.9998281],
        [-3.3996763]], dtype=float32))
true_b, model.get_weights()[1]
(4.2, array([4.1998463], dtype=float32))
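
As a further numeric check (our own snippet, assuming NumPy is available):

import numpy as np

w, b = model.get_weights()
# largest absolute gap between the learned and true parameters
print('max weight error:', np.abs(np.array(true_w).reshape(-1, 1) - w).max())
print('bias error:', abs(true_b - b[0]))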
loss_history
[29.501896,
 37.45631,
 12.249255,
 23.72889,
 23.883945,
 ...
 7.548153e-05,
 0.00014087862,
 0.00011540208]

(output truncated: the per-batch loss falls from roughly 30 at the start of training to around 1e-4 by the end)

With TensorFlow we can implement the model concisely: the tensorflow.data module provides data-processing utilities, tensorflow.keras.layers defines a large number of neural network layers, tensorflow.initializers defines various initialization methods, and tensorflow.optimizers provides various optimization algorithms for the model.

