基於tensorflow的線性迴歸

import tensorflow as tf

# 初始化變量和模型參數,定義訓練閉環中的運算
W = tf.Variable(tf.zeros([2, 1]), name="weights")
b = tf.Variable(0., name="bias")


def inference(X):  # 計算推斷模型在數據X上的輸出,並將結果保存
    return tf.matmul(X, W) + b


def loss(X, Y):  # 依據訓練數據X和期望輸出Y計算損失
    Y_predicted = inference(X)
    return tf.reduce_sum(tf.squared_difference(Y, Y_predicted))


def inputs():  # 讀取或生成訓練數據X及其期望輸出Y
    weight_age = [[84, 46], [73, 20], [65, 52], [70, 30], [76, 57], [69, 25], [63, 28], [72, 36], [79, 57],
                  [75, 44], [27, 24], [89, 31], [65, 52], [57, 23], [59, 60], [69, 48], [60, 34], [79, 51],
                  [75, 50], [82, 34], [59, 46], [67, 23], [85, 37], [55, 40], [63, 30]]
    blood_fat_content = [354, 190, 405, 263, 451, 302, 288, 385, 402, 365, 209, 346, 254, 395, 434, 220, 374, 308, 220,
                         311, 181, 274, 303, 244]
    return tf.to_float(weight_age), tf.to_float(blood_fat_content)


def train(total_loss):  # 依據計算的總損失訓練或調整模型參數
    learning_rate = 0.0000001
    return tf.train.GradientDescentOptimizer(learning_rate).minimize(total_loss)


def evaluate(sess, X, Y):  # 對訓練得到的模型進行評估
    print(sess.run(inference([[80., 25.]])))
    print(sess.run(inference([[65., 25.]])))


# 在一個會話對象中啓動數據流圖,搭建流程
with tf.Session() as sess:
    tf.initialize_all_variables().run()
    X, Y = inputs()

    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # 實際的訓練迭代次數
    training_steps = 1000
    for step in range(training_steps):
        sess.run([train_op])
        # 處於調試和學習的目的,查看損失在訓練過程中的遞減情況
        if step % 10 == 0:
            print("loss:", sess.run([total_loss]))
    evaluate(sess, X, Y)
    coord.request_stop()
    coord.join(threads)
    sess.close()

對於這種簡單的模型,將採用總平方誤差,即模型對每個訓練樣本的預測值與期望輸出之差的平方的總和。從代數角度看,這個損失函數實際上是預測的輸出向量與期望向量之間歐氏距離的平方。對於2D數據集,總平方誤差對應於每個數據點在垂直方向上到所預測的迴歸直線的距離的平方總和。這種損失函數也稱爲L2範數或L2損失函數。這裏之所以採用平方,是爲了避免計算平方根,因爲對於最小化損失這個目標,有無平方並無本質區別,但有平方可以節省一定的計算量。

數據來源:http://people.sc.fsu.edu/~jburkardt/datasets/regression/x09.txt

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章