機器學習梯度下降法舉例

現在有一組數據,plot之後如下圖

x = np.array([1,3,4,7,9,10,11,18])
y = np.array([5,12,12,20,32,30,33,60])

在這裏插入圖片描述
現在問:當x等於8的時候,y的取值是多少?
首先看這個圖,具有線性,於是我們假設先用一條直線擬合(即一次函數)
採用梯度下降法
我們的hypothesis爲
在這裏插入圖片描述
有兩個參數,使用梯度下降法參數更新爲
在這裏插入圖片描述
用python寫code

import numpy as np
import matplotlib.pyplot as plt

x = np.array([1,3,4,7,9,10,11,18])
y = np.array([5,12,12,20,32,30,33,60])

def Gradient_descent(x):
    alpha = 0.001; m = len(x)
    theta0,theta1 = np.random.randint(1,6,2) #initialize two parameters with int between [1,6)
    while(True):
        h = theta0 + theta1*x
        J = 1/(2*m)*np.sum(((h-y)**2))
        theta0 = theta0 - alpha*1/m*np.sum((h-y))
        theta1 = theta1 - alpha*1/m*np.sum(((h-y)*x))
        print(J,theta0,theta1)
        if J <=5:
            return theta0, theta1
            break

if __name__ == "__main__":
    theta0,theta1 = Gradient_descent(x)
    #plot data
    x1 = np.linspace(np.min(x),np.max(x)+1)
    y1 = theta0 + theta1*x1
    plt.scatter(x,y,marker='x')
    plt.plot(x1,y1)
    plt.show()

結果如下
在這裏插入圖片描述

43.375 3.00675 2.083625
36.64715182128906 3.012834703125 2.159869203125
31.055311490596658 3.0183128984472654 2.2293845909140626
26.407644730247434 3.02323818189537 2.2927655520599455
22.544711758531772 3.0276594149910023 2.3505539698782667
19.333996925479305 3.03162114306322 2.4032438603746296
16.665362113113904 3.0351639765197063 2.45128560060768
14.447267906771756 3.038324938438401 2.4950897835393393
12.603632886217634 3.0411377814545903 2.5350307323665024
11.071223279127338 3.0436332766557497 2.571449704413933
9.797483416187205 3.045839476956834 2.604657812010998
8.738732550924968 3.0477819572102907 2.634938685352499
7.858666177326327 3.0494840331059296 2.6625509001354555
7.127110425582591 3.050966960734257 2.687730190750377
6.5189867994023265 3.0522501185213633 2.710691467970093
6.013451735119135 3.0533511730925773 2.731630658405858
5.593181460948169 3.0542862304845384 2.7507263814749403
5.243777619977898 3.0550699739999385 2.768141478233133
4.9532732638769845 3.055715789884853 2.784024405157705

第一列是平方誤差,可以看到誤差在逐漸減小。
注意:code中的alpha是learning rate,這裏我取0.001;此處不能取很大,比如取1則結果會越來越差,平方誤差會越來越大。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章