線性迴歸的原理及實踐(牛頓法)

至此分類算法在前面都學完了,下面將進行迴歸算法的學習。

迴歸算法和分類算法都屬於監督學習算法,不同的是分類算法中標籤的是一些離散的值,代表不同的類別,而在迴歸算法中,標籤是一些連續的值迴歸算法需要訓練得到樣本特徵到這些連續標籤之間的映射線性迴歸是一類重要的迴歸問題,在線性迴歸中,目標值與特徵之間存在着線性相關的關係。

一、線性迴歸

基本形式:給定由 d 個屬性描述的示例 x=(x_{1};x_{2};...;x_{d}) ,線性模型試圖學得一個通過屬性的線性組合來進行預測的函數,即:

                                                            f(x)=w_{1}x_{1}+w_{2}x_{2}+...+w_{d}x_{d}+b

一般形式爲:

                                                                           f(x)=w^{T}x+b

w 和 b 學得之後,模型就得以確定。

線性迴歸試圖學得一個線性模型以儘可能準確地預測實值輸出標記。

1、當輸入屬性就一個時線性迴歸試圖學得:

                                                         f(x)=wx_{i}+b ,使得   f(x_{i})\simeq y_{i}

    如何確定 w 和 b 那?關鍵是衡量 f 和 y 之間的差別。均方誤差是迴歸任務中最常用的性能度量,因此可以使均方誤差最小。基於均方誤差最小化來進行模型求解的方法稱爲“最小二乘法”。在線性迴歸中,最小二乘法就是找到一條直線,使所有樣本到直線的歐式距離和最小。即:    

2、輸入d個屬性時試圖學得:

二、牛頓法

除了梯度下降法,牛頓法也是機器學習中用的比較多的一種優化算法。牛頓法的基本思想是利用迭代點  x_{k} 處的一階導數 (梯度)和二階導數 ( Hessen 矩陣) 對目標函數進行二次函數近似,然後把二次模型的極小點作爲新的迭代點,並不斷重複這一過程,直至求得滿足精度的近似極小值。牛頓法的速度相當快,而且能高度逼近最優值。牛頓法最突出的優點是收斂速度快,具有局部二階收斂性,其分爲基本牛頓法和全局牛頓法。

1.基本牛頓法

基本牛頓法是基於導數的算法,他每一步的迭代方向都是沿着當前點函數值下降的方向。對於一維的情形,對需要求解的優化函數 f(x),求函數的極值的問題可以轉化爲求導函數 f^{,}(x)=0。對 f(x) 進行泰勒展開到二階,得:

                                                               f(x)=f(x_{k})+{f}'(x_{k})(x-x_{k})+\frac{1}{2}{f}''(x_{k})(x-x_{k})^{2}

對上式求導並令其爲0,則:

                                                                            {f}'(x_{k})+{f}''(x_{k})(x-x_{k})=0

即得到:

                                                                                       x=x_{k}-\frac{{f}'(x_{k})}{{f}''(x_{k})}

這就是牛頓法的更新公式。

基本牛頓法的流程:

2.全局牛頓法

基本牛頓法初始時需要足夠靠近極小點,否則將會導致算法不收斂,這時引入全局牛頓法。

全局牛頓法的流程爲:

3.Armijo 搜索

全局牛頓法是基於 Armijo 的搜索,滿足 Armijo 準則:

給定 \beta \in (0,1) ,\sigma \in (0,0.5),令步長因子 \alpha _{k}=\beta ^{m_{k}} ,其中 m_{k} 是滿足下列不等式的最小非負整數:

三、利用線性迴歸進行預測

# -*- coding: utf-8 -*-
"""
Created on Thu Mar 21 20:43:26 2019

@author: 2018061801
"""

import numpy as np
from math import pow

def load_data(file_path):
    '''導入數據
    input:  file_path(string):訓練數據
    output: feature(mat):特徵
            label(mat):標籤
    '''
    f = open(file_path)
    feature = []
    label = []
    for line in f.readlines():
        feature_tmp = []
        lines = line.strip().split("\t")
        feature_tmp.append(1)  # x0
        for i in range(len(lines) - 1):
            feature_tmp.append(float(lines[i]))
        feature.append(feature_tmp)
        label.append(float(lines[-1]))
    f.close()
    return np.mat(feature), np.mat(label).T

def least_square(feature, label):
    '''最小二乘法
    input:  feature(mat):特徵
            label(mat):標籤
    output: w(mat):迴歸係數
    '''
    w = (feature.T * feature).I * feature.T * label
    return w

def first_derivativ(feature, label, w):
    '''計算一階導函數的值
    input:  feature(mat):特徵
            label(mat):標籤
    output: g(mat):一階導數值
    '''
    m, n = np.shape(feature)
    g = np.mat(np.zeros((n, 1)))
    for i in range(m):
        err = label[i, 0] - feature[i, ] * w
        for j in range(n):
            g[j, ] -= err * feature[i, j]
    return g     

def second_derivative(feature):
    '''計算二階導函數的值
    input:  feature(mat):特徵
    output: G(mat):二階導數值
    '''
    m, n = np.shape(feature)
    G = np.mat(np.zeros((n, n)))
    for i in range(m):
        x_left = feature[i, ].T
        x_right = feature[i, ]
        G += x_left * x_right
    return G

def get_error(feature, label, w):
    '''計算誤差
    input:  feature(mat):特徵
            label(mat):標籤
            w(mat):線性迴歸模型的參數
    output: 損失函數值
    '''
    return (label - feature * w).T * (label - feature * w) / 2

def get_min_m(feature, label, sigma, delta, d, w, g):
    '''計算步長中最小的值m
    input:  feature(mat):特徵
            label(mat):標籤
            sigma(float),delta(float):全局牛頓法的參數
            d(mat):負的一階導數除以二階導數值
            g(mat):一階導數值
    output: m(int):最小m值
    '''
    m = 0
    while True:
        w_new = w + pow(sigma, m) * d
        left = get_error(feature, label , w_new)
        right = get_error(feature, label , w) + delta * pow(sigma, m) * g.T * d
        if left <= right:
            break
        else:
            m += 1
    return m           

def newton(feature, label, iterMax, sigma, delta):
    '''牛頓法
    input:  feature(mat):特徵
            label(mat):標籤
            iterMax(int):最大迭代次數
            sigma(float), delta(float):牛頓法中的參數
    output: w(mat):迴歸係數
    '''
    n = np.shape(feature)[1]
    w = np.mat(np.zeros((n, 1)))
    it = 0
    while it <= iterMax:
        # print it
        g = first_derivativ(feature, label, w)  # 一階導數
        G = second_derivative(feature)  # 二階導數
        d = -G.I * g
        m = get_min_m(feature, label, sigma, delta, d, w, g)  # 得到最小的m
        w = w + pow(sigma, m) * d
        if it % 10 == 0:
            print ("\t---- itration: ", it, " , error: ", get_error(feature, label , w)[0, 0])
        it += 1       
    return w

def save_model(file_name, w):
    '''保存最終的模型
    input:  file_name(string):要保存的文件的名稱
            w(mat):訓練好的線性迴歸模型
    '''
    f_result = open(file_name, "w")
    m, n = np.shape(w)
    for i in range(m):
        w_tmp = []
        for j in range(n):
            w_tmp.append(str(w[i, j]))
        f_result.write("\t".join(w_tmp) + "\n")
    f_result.close()
    

if __name__ == "__main__":
    # 1、導入數據集
    print ("----------- 1.load data ----------")
    feature, label = load_data("D:/anaconda4.3/spyder_work/data2.txt")
    # 2.1、最小二乘求解
    print ("----------- 2.training ----------")
    print ("\t ---------- least_square ----------")
    w_ls = least_square(feature, label)
    # 2.2、牛頓法
    #print ("\t ---------- newton ----------")
    #w_newton = newton(feature, label, 50, 0.1, 0.5)
    # 3、保存最終的結果
    print ("----------- 3.save result ----------")
    save_model("weights", w_ls)
    

使用全局牛頓法的結果:

----------- 1.load data ----------
----------- 2.training ----------
         ---------- newton ----------
        ---- itration:  0  , error:  12.346444091730936
        ---- itration:  10  , error:  0.07017065415130548
        ---- itration:  20  , error:  0.07017065415130548
        ---- itration:  30  , error:  0.07017065415130548
        ---- itration:  40  , error:  0.07017065415130548
        ---- itration:  50  , error:  0.07017065415130548
----------- 3.save result ----------

使用最小二乘法結果:

----------- 1.load data ----------
----------- 2.training ----------
         ---------- least_square ----------
----------- 3.save result ----------

對於使用最小二乘法和全局牛頓法,線性迴歸模型最終得到相同的參數值:

w0=0.0031049944337919275

w1=0.9945024703102509

訓練數據:

0.422285418967358	0.429005468089679
0.548811087562498	0.532492334219154
0.0239703698338769	0.0548126922054728
0.611366447087570	0.585248758640251
0.252719613340340	0.285278442421827
0.265421160090244	0.294045385350157
0.219072519567866	0.236511959989685
0.401861387493983	0.375684767186615
0.208115403157731	0.205472741349725
0.239844960876992	0.272723041971107
0.229903140619694	0.258764519520563
0.430828843674200	0.451484834079148
0.448723772214619	0.481495140915047
0.679206578570127	0.709728292693527
0.647192437871312	0.674913329155585
0.230590652199338	0.259896354413699
0.148270899244940	0.121926223582234
0.636281127313104	0.651125466408602
0.205732203664697	0.198436666800684
0.399610991227472	0.371592207288394
0.636183808150379	0.650870007189679
0.0472145428294716	0.0788533881645130
0.193143008904441	0.167631814179908
0.291582698587355	0.306284442203354
0.411417436503224	0.397395169553808
0.628991936629693	0.630338016215000
0.169284716438347	0.136231147351670
0.660361352093842	0.693483286168938
0.298684102512952	0.308889826735502
0.561017462906972	0.537733929537329
0.384320533934917	0.351006848533773
0.221496695817379	0.242290898136858
0.296223865312843	0.308007240154126
0.492684728153993	0.512377513273133
0.634998145179060	0.647698764185059
0.625657915955206	0.620381266360076
0.0891660562602302	0.0994224254674202
0.0742086102687974	0.0936692472873702
0.626865780781351	0.623967608736031
0.441957883339557	0.472361875114246
0.0157191849402231	0.0406701224083588
0.368674561676324	0.337927787300071
0.627217178624091	0.625017672954424
0.253846304247748	0.286173545973494
0.206928529957397	0.201944099043841
0.180152118526450	0.147499079435076
0.155349951472829	0.125977748839280
0.249693546239582	0.282734942340199
0.417614851667815	0.415091347480293
0.197917834658727	0.177822530291359
0.427647487969975	0.443772210784627
0.578657203647726	0.547737030187470
0.0591439520498619	0.0864158098819885
0.504071412195014	0.516865018298665
0.439522870310511	0.468503499870856
0.340869575847922	0.323646896848778
0.641780686860910	0.664296267194680
0.0525360513031822	0.0824850997758311
0.419126940809084	0.419622744902513
0.179854209370586	0.147136383531176
0.0121642783764464	0.0330953951133111
0.412227568135505	0.399575401149883
0.604841264477251	0.574236148799604
0.402583176912226	0.377072606458986
0.603523505984253	0.572312679990502
0.158762606652888	0.128170576482325
0.427182367147573	0.442574474390365
0.129992272434756	0.113636969780914
0.206035553827662	0.199320192463950
0.00663776019535191	0.0192663492512525
0.477238063323136	0.504856882892435
0.651289847573691	0.681638072890728
0.524660387409372	0.523952681654443
0.532747289113760	0.526671056162354
0.394888276902191	0.364036428720525
0.203313142052076	0.191574879978562
0.186561941455151	0.156262844191082
0.204876351181546	0.195969153772640
0.657841775455899	0.690546398866122
0.178007951093262	0.144965532795643
0.392104389907160	0.360145588903358
0.158809561899624	0.128201991205053
0.298746011795488	0.308911800576876
0.0969730534561077	0.102123708000668
0.246251203302465	0.279566081649905
0.569115068981751	0.541869808520661
0.585272752016451	0.552669159180588
0.630492187715448	0.634815081327384
0.393905947217334	0.362620080614826
0.589081572083947	0.555927065039807
0.483579152818206	0.508204150647193
0.0171145855736783	0.0433778070676218
0.100884656221289	0.103439716366314
0.179709022044687	0.146960905786852
0.187562784860364	0.157818728729414
0.126211524490225	0.112184315667860
0.180790786560970	0.148288782503365
0.200791269966831	0.184854236366562
0.651021486883853	0.681227789754590
0.228673106482575	0.256666270032256
0.208782263150292	0.201734022274195
0.210503575395043	0.228067485736071
0.235895301376771	0.245240124609668
0.425622581490282	0.429764168993690
0.237837661606853	0.244385096841278
0.616955289459391	0.594545102911351
0.0376932276156365	0.0952635834401339
0.203176291567000	0.198030612099198
0.447104292379001	0.461498302305307
0.655886611466805	0.698000466195644
0.394438273213089	0.385519634230031
0.402654548839040	0.348495792061340
0.348759401363818	0.345472506033022
0.511054635769362	0.523696740564166
0.00316332449308267	0.0225122980416280
0.314818282272595	0.323955952269789
0.529065739377704	0.543710285781855
0.0108301689715363	0.0467766821290459
0.407258439427563	0.361599256201253
0.190151926658210	0.140125899091758
0.406026174866894	0.386696192598053
0.0317051763169514	0.0940907024508717
0.189882277107736	0.150579718983628
0.209949766870384	0.183492303743452
0.386129592852040	0.338293568165873
0.273665516040774	0.325379231947505
0.636167150788095	0.628100827832926
0.416613293692747	0.403688605299293
0.494069542440626	0.514373228515084
0.194609675237531	0.170155775899871
0.611898884546681	0.599735753459407
0.153748283237577	0.104443291773365
0.242711804416565	0.297690243571762
0.463156223019373	0.503563329092201
0.0960441400637809	0.0765666023019307
0.620558726023471	0.601270482591644
0.353506511758655	0.326842067996216
0.377410664344231	0.330857148517901
0.178583437819129	0.169600380065018
0.632238411096283	0.646156417339950
0.239280256965016	0.277966249392183
0.451355840070360	0.508200010155423
0.0450423759253215	0.0795679785032915
0.293007592738525	0.313851527339232
0.403511713579458	0.362465555452099
0.0596389824166379	0.0572870188772311
0.0256322708904266	0.0465414970455044
0.555865901075772	0.525529923690178
0.426530196385284	0.470439004166721
0.439215848521695	0.470378673630777
0.179018740264693	0.160527847630221
0.625290218639049	0.640887522129697
0.418410724661046	0.426759508002257
0.634549153440050	0.642461422657716
0.206114431420305	0.207463564791586
0.208116012762849	0.212796289939528
0.0192654880536059	0.0769589775293750
0.160271632281092	0.135640013867957
0.229165208106023	0.270581471939234
0.0103522921134037	0.00284882679711455
0.639277385492713	0.633878717364796
0.221977964704073	0.244950875702800
0.187530818135169	0.160342122941687
0.557219810661823	0.537308123068441
0.641357769311235	0.642198856948541
0.281351013720136	0.288938540729609
0.418774620556121	0.390754314844030
0.0283657133768362	0.0501683412053051
0.0542000361133447	0.0637325840016054
0.243177527351406	0.248837488278470
0.484352451963199	0.517065364596454
0.130032441190971	0.0873818795740706
0.271763748539046	0.295343215432390
0.156834016045304	0.149082996569873
0.418718636951629	0.389778692023851
0.628958675610216	0.627415103272291
0.0249749905201242	0.0326378191743664
0.684659242241295	0.735115810087691
0.424882915843142	0.408021868456478
0.217925167868618	0.231433096036625
0.0547546652403919	0.0881287495056403
0.219256756230378	0.233997915839033
0.411463504258259	0.372008467362171
0.0245806793943106	0.0437564438317034
0.381858587724410	0.352263974674522
0.184685085405949	0.178254802473518
0.594648982458728	0.558051748155136
0.671119229155300	0.708464413032835
0.180395208041911	0.166433113457345
0.316757849993797	0.300024130486124
0.653185585082381	0.662985980247932
0.375015494409235	0.370888752538540
0.439380980609484	0.475148149367430
0.335822265332407	0.326355097208739
0.633184990422054	0.635226744623378
0.176732026339695	0.138950147503156
0.200052880581140	0.179655779303571
0.429693006890599	0.450906533983589
0.253621749753599	0.308404280155269
0.209841851399314	0.204648160393728

 

 

 

訓練數據

參考文獻:趙志勇《python 機器學習算法》(程序)

周志華《機器學習》
 

                                                    

 

 

 

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章