11. Locally Weighted Linear Regression Case Study: Predicting the Age of Abalone

Dataset: https://download.csdn.net/download/weixin_44827418/12553408

1. Import the dataset

Dataset description: [figure: image unavailable]

import pandas as pd
import numpy as np

abalone = pd.read_table("./datas/abalone.txt",header=None)
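# Column names, in order: sex, length, diameter, height, whole weight, meat (shucked) weight, viscera weight, shell weight, age
abalone.columns=['性別','長度','直徑','高度','整體重量','肉重量','內臟重量','殼重','年齡']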
abalone.head()
性別 長度 直徑 高度 整體重量 肉重量 內臟重量 殼重 年齡
0 1 0.455 0.365 0.095 0.5140 0.2245 0.1010 0.150 15
1 1 0.350 0.265 0.090 0.2255 0.0995 0.0485 0.070 7
2 -1 0.530 0.420 0.135 0.6770 0.2565 0.1415 0.210 9
3 1 0.440 0.365 0.125 0.5160 0.2155 0.1140 0.155 10
4 0 0.330 0.255 0.080 0.2050 0.0895 0.0395 0.055 7
abalone.shape
(4177, 9)
abalone.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4177 entries, 0 to 4176
Data columns (total 9 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   性別      4177 non-null   int64  
 1   長度      4177 non-null   float64
 2   直徑      4177 non-null   float64
 3   高度      4177 non-null   float64
 4   整體重量    4177 non-null   float64
 5   肉重量     4177 non-null   float64
 6   內臟重量    4177 non-null   float64
 7   殼重      4177 non-null   float64
 8   年齡      4177 non-null   int64  
dtypes: float64(7), int64(2)
memory usage: 293.8 KB
abalone.describe()
性別 長度 直徑 高度 整體重量 肉重量 內臟重量 殼重 年齡
count 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000
mean 0.052909 0.523992 0.407881 0.139516 0.828742 0.359367 0.180594 0.238831 9.933684
std 0.822240 0.120093 0.099240 0.041827 0.490389 0.221963 0.109614 0.139203 3.224169
min -1.000000 0.075000 0.055000 0.000000 0.002000 0.001000 0.000500 0.001500 1.000000
25% -1.000000 0.450000 0.350000 0.115000 0.441500 0.186000 0.093500 0.130000 8.000000
50% 0.000000 0.545000 0.425000 0.140000 0.799500 0.336000 0.171000 0.234000 9.000000
75% 1.000000 0.615000 0.480000 0.165000 1.153000 0.502000 0.253000 0.329000 11.000000
max 1.000000 0.815000 0.650000 1.130000 2.825500 1.488000 0.760000 1.005000 29.000000

2. Inspect the data distribution

import numpy as np
import pandas as pd
import random
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['simhei'] # font that can render the Chinese column names
plt.rcParams['axes.unicode_minus']=False # render minus signs correctly
%matplotlib inline
mpl.cm.rainbow(np.linspace(0,1,10))
array([[5.00000000e-01, 0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
       [2.80392157e-01, 3.38158275e-01, 9.85162233e-01, 1.00000000e+00],
       [6.07843137e-02, 6.36474236e-01, 9.41089253e-01, 1.00000000e+00],
       [1.66666667e-01, 8.66025404e-01, 8.66025404e-01, 1.00000000e+00],
       [3.86274510e-01, 9.84086337e-01, 7.67362681e-01, 1.00000000e+00],
       [6.13725490e-01, 9.84086337e-01, 6.41213315e-01, 1.00000000e+00],
       [8.33333333e-01, 8.66025404e-01, 5.00000000e-01, 1.00000000e+00],
       [1.00000000e+00, 6.36474236e-01, 3.38158275e-01, 1.00000000e+00],
       [1.00000000e+00, 3.38158275e-01, 1.71625679e-01, 1.00000000e+00],
       [1.00000000e+00, 1.22464680e-16, 6.12323400e-17, 1.00000000e+00]])
mpl.cm.rainbow(np.linspace(0,1,10))[0]
array([0.5, 0. , 1. , 1. ])
def dataPlot(dataSet):
    m,n = dataSet.shape
    fig = plt.figure(figsize=(8,20),dpi=100)
    colormap = mpl.cm.rainbow(np.linspace(0,1,n))
    for i in range(n):
        ax = fig.add_subplot(n,1,i+1)
        # color= takes one RGBA value for all points; passing it as c= triggers a matplotlib warning
        ax.scatter(range(m),dataSet.iloc[:,i].values,s=2,color=colormap[i])
        ax.set_title(dataSet.columns[i])
    fig.tight_layout(pad=1.2) # adjust the spacing between subplots
# Run the function to inspect the data distribution:
dataPlot(abalone)

[Figure output_10_1.png: scatter plot of every column of abalone; image unavailable]

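The scatter plots show two things:

1) Apart from 性別 (sex), the columns show a clear pattern along the row order, so the rows are not randomly ordered.

2) The 高度 (height) feature contains two outliers.

Based on these observations, we take two measures:

1) Shuffle the original dataset before splitting it into a training set and a test set, so that rows are picked at random.

2) Remove the outliers in the 高度 (height) feature.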

abalone['高度']<0.4
0       True
1       True
2       True
3       True
4       True
        ... 
4172    True
4173    True
4174    True
4175    True
4176    True
Name: 高度, Length: 4177, dtype: bool
# Keep only rows with height below 0.4; reset the index so the row labels stay contiguous (randSplit selects rows by label)
aba = abalone.loc[abalone['高度']<0.4,:].reset_index(drop=True)
# Check the distribution again
dataPlot(aba)

[Figure output_18_1.png: scatter plot of every column of aba after removing the height outliers; image unavailable]

2. Split into training and test sets

"""
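Purpose: randomly split a dataset into a training set and a test set
Parameters:
    dataSet: the original dataset (its index labels must be 0..m-1)
    rate: fraction of rows used for training
Returns:
    train,test: the training set and the test set
"""
def randSplit(dataSet,rate):
    l = list(dataSet.index) # copy the index labels into a list
    random.seed(123) # fix the random seed so the split is reproducible
    random.shuffle(l) # shuffle the index labels
    dataSet.index = l # assign the shuffled labels back to the dataset;
    # selecting labels in order now yields the rows in random order
    m = dataSet.shape[0] # total number of samples
    n = int(m*rate) # number of training samples
    train = dataSet.loc[range(n),:] # rows labelled 0..n-1 form the training set
    test = dataSet.loc[range(n,m),:] # rows labelled n..m-1 form the test set
    train.index = range(train.shape[0]) # reset the training-set index
    test.index = range(test.shape[0]) # reset the test-set index
    dataSet.index = range(dataSet.shape[0]) # restore a 0..m-1 index on the original dataset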
    return train,test
train,test = randSplit(aba,0.8)
# Explore the training set
train.head()
性別 長度 直徑 高度 整體重量 肉重量 內臟重量 殼重 年齡
0 -1 0.590 0.470 0.170 0.9000 0.3550 0.1905 0.2500 11
1 1 0.560 0.450 0.145 0.9355 0.4250 0.1645 0.2725 11
2 -1 0.635 0.535 0.190 1.2420 0.5760 0.2475 0.3900 14
3 1 0.505 0.390 0.115 0.5585 0.2575 0.1190 0.1535 8
4 1 0.510 0.410 0.145 0.7960 0.3865 0.1815 0.1955 8
train.shape
(3340, 9)
abalone.describe()
性別 長度 直徑 高度 整體重量 肉重量 內臟重量 殼重 年齡
count 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000
mean 0.052909 0.523992 0.407881 0.139516 0.828742 0.359367 0.180594 0.238831 9.933684
std 0.822240 0.120093 0.099240 0.041827 0.490389 0.221963 0.109614 0.139203 3.224169
min -1.000000 0.075000 0.055000 0.000000 0.002000 0.001000 0.000500 0.001500 1.000000
25% -1.000000 0.450000 0.350000 0.115000 0.441500 0.186000 0.093500 0.130000 8.000000
50% 0.000000 0.545000 0.425000 0.140000 0.799500 0.336000 0.171000 0.234000 9.000000
75% 1.000000 0.615000 0.480000 0.165000 1.153000 0.502000 0.253000 0.329000 11.000000
max 1.000000 0.815000 0.650000 1.130000 2.825500 1.488000 0.760000 1.005000 29.000000
train.describe() # summary statistics
性別 長度 直徑 高度 整體重量 肉重量 內臟重量 殼重 年齡
count 3340.000000 3340.000000 3340.000000 3340.000000 3340.000000 3340.000000 3340.000000 3340.000000 3340.000000
mean 0.060479 0.522754 0.406886 0.138790 0.824906 0.358151 0.179732 0.237158 9.911976
std 0.819021 0.120300 0.099372 0.038441 0.488535 0.222422 0.109036 0.137920 3.223534
min -1.000000 0.075000 0.055000 0.000000 0.002000 0.001000 0.000500 0.001500 1.000000
25% -1.000000 0.450000 0.350000 0.115000 0.439000 0.184375 0.092000 0.130000 8.000000
50% 0.000000 0.540000 0.420000 0.140000 0.796750 0.335500 0.171000 0.232000 9.000000
75% 1.000000 0.615000 0.480000 0.165000 1.147250 0.498500 0.250500 0.325000 11.000000
max 1.000000 0.780000 0.630000 0.250000 2.825500 1.488000 0.760000 1.005000 27.000000
dataPlot(train) # inspect the training-set distribution

[Figure output_26_1.png: scatter plot of every column of the training set; image unavailable]

# Explore the test set
test.head() 
性別 長度 直徑 高度 整體重量 肉重量 內臟重量 殼重 年齡
0 1 0.630 0.470 0.150 1.1355 0.5390 0.2325 0.3115 12
1 -1 0.585 0.445 0.140 0.9130 0.4305 0.2205 0.2530 10
2 -1 0.390 0.290 0.125 0.3055 0.1210 0.0820 0.0900 7
3 1 0.525 0.410 0.130 0.9900 0.3865 0.2430 0.2950 15
4 1 0.625 0.475 0.160 1.0845 0.5005 0.2355 0.3105 10
test.shape 
(835, 9)
test.describe() 
性別 長度 直徑 高度 整體重量 肉重量 內臟重量 殼重 年齡
count 835.000000 835.000000 835.000000 835.000000 835.000000 835.000000 835.000000 835.000000 835.000000
mean 0.022754 0.528808 0.411737 0.140784 0.842714 0.363370 0.183749 0.245320 10.022754
std 0.834341 0.119166 0.098627 0.038664 0.495990 0.218938 0.111510 0.143925 3.230284
min -1.000000 0.130000 0.100000 0.015000 0.013000 0.004500 0.003000 0.004000 3.000000
25% -1.000000 0.450000 0.350000 0.115000 0.458000 0.192000 0.096500 0.132750 8.000000
50% 0.000000 0.550000 0.430000 0.140000 0.810000 0.339000 0.170500 0.235000 10.000000
75% 1.000000 0.620000 0.485000 0.170000 1.177250 0.510750 0.259250 0.337000 11.000000
max 1.000000 0.815000 0.650000 0.250000 2.555000 1.145500 0.590000 0.815000 29.000000
dataPlot(test)

[Figure output_30_1.png: scatter plot of every column of the test set; image unavailable]
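
The shuffle-then-cut idea behind randSplit can also be sketched with pandas' own `DataFrame.sample`, which leaves the input frame untouched. This is only an illustrative alternative, not the author's method: the function name `rand_split_sample`, the `seed` parameter, and the toy frame are assumptions, and the rows it picks will differ from the split shown above.

```python
import pandas as pd

def rand_split_sample(data_set, rate, seed=123):
    """Hypothetical alternative to randSplit: shuffle the rows, then cut at `rate`."""
    # sample(frac=1) returns every row in random order; data_set itself is not modified
    shuffled = data_set.sample(frac=1, random_state=seed).reset_index(drop=True)
    n_train = int(len(shuffled) * rate)
    train = shuffled.iloc[:n_train].reset_index(drop=True)
    test = shuffled.iloc[n_train:].reset_index(drop=True)
    return train, test

# Toy frame standing in for `aba` (values are made up for illustration)
toy = pd.DataFrame({"x": range(10), "y": range(10, 20)})
train_toy, test_toy = rand_split_sample(toy, 0.8)
print(train_toy.shape, test_toy.shape)
```

Because the cut is positional (`iloc`), this variant does not depend on the index labels being contiguous.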

3. Build helper functions

'''
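Purpose: take a DataFrame whose last column is the label and return the feature matrix and the label matrix
'''
def get_Mat(dataSet):
    xMat = np.asmatrix(dataSet.iloc[:,:-1].values) # np.asmatrix replaces np.mat, which NumPy 2.0 removed
    yMat = np.asmatrix(dataSet.iloc[:,-1].values).T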
    return xMat,yMat

'''
Purpose: visualize a dataset
'''
def plotShow(dataSet):
    xMat,yMat = get_Mat(dataSet)
    plt.scatter(xMat.A[:,1],yMat.A,c='b',s=5)
    plt.show()

'''
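Purpose: compute the regression coefficients
Parameters:
    dataSet: the original dataset
Returns:
    ws: the regression coefficients (None if xTx is singular)
'''
def standRegres(dataSet):
    xMat,yMat = get_Mat(dataSet)
    xTx = xMat.T * xMat
    if np.linalg.det(xTx) == 0:
        print('The matrix is singular and cannot be inverted!')
        return
    ws = xTx.I*(xMat.T*yMat) # xTx.I is the inverse of xTx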
    return ws
"""
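Purpose: compute the sum of squared errors (SSE)
Parameters:
    dataSet: dataset containing the features and the true labels
    regres: function that computes the regression coefficients
Returns:
    SSE: the sum of squared errors
"""
def sseCal(dataSet, regres):
    xMat,yMat = get_Mat(dataSet)
    ws = regres(dataSet)
    yHat = xMat*ws
    sse = ((yMat.A.flatten() - yHat.A.flatten())**2).sum()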
    return sse

Using the ex0 dataset as an example, check what the functions return:

ex0 = pd.read_table("./datas/ex0.txt",header=None)
ex0.head()
0 1 2
0 1.0 0.067732 3.176513
1 1.0 0.427810 3.816464
2 1.0 0.995731 4.550095
3 1.0 0.738336 4.256571
4 1.0 0.981083 4.560815
# SSE of simple linear regression
sseCal(ex0, standRegres)
1.3552490816814902
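
The coefficients that standRegres computes come from the normal equation, w = (XᵀX)⁻¹Xᵀy. As a rough cross-check, the sketch below solves the same equation on synthetic data and compares it with NumPy's least-squares solver. The data are made up (true intercept 3.0 and slope 1.7 are assumptions chosen for illustration) and only mimic the layout of ex0: a column of ones, one feature, then the label.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for ex0: constant column, one feature, noisy linear label
x = rng.uniform(0, 1, size=200)
X = np.column_stack([np.ones_like(x), x])
y = 3.0 + 1.7 * x + rng.normal(0, 0.1, size=200)

# Normal equation (X^T X) w = X^T y, solved without forming the inverse explicitly
w_normal = np.linalg.solve(X.T @ X, X.T @ y)
# General least-squares solver, which also copes with rank-deficient X
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)

# Sum of squared errors of the fitted line
sse = float(((y - X @ w_normal) ** 2).sum())
print(w_normal, w_lstsq, sse)
```

Both routes should agree to floating-point precision; `np.linalg.lstsq` is usually the safer choice when XᵀX may be close to singular.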

Build a function to compute the coefficient of determination R²

"""
Purpose: compute the coefficient of determination R²
"""
def rSquare(dataSet,regres):
    xMat,yMat=get_Mat(dataSet)
    sse = sseCal(dataSet,regres)
    sst = ((yMat.A-yMat.mean())**2).sum()
    r2 = 1 - sse / sst
    return r2

Again using the ex0 dataset as an example, check what the function returns:

# R² of simple linear regression
rSquare(ex0, standRegres)
0.9731300889856916
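
R² is defined as 1 - SSE/SST. For a straight-line fit with an intercept and a single feature, it equals the square of the Pearson correlation between the feature and the label, which is probably why it is sometimes loosely called a correlation coefficient. The sketch below checks that identity on synthetic data; the numbers are made up and are not ex0.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.uniform(0, 1, size=100)
y = 2.0 + 0.5 * x + rng.normal(0, 0.05, size=100)

# Ordinary least squares with an intercept column
X = np.column_stack([np.ones_like(x), x])
w, *_ = np.linalg.lstsq(X, y, rcond=None)
y_hat = X @ w

sse = ((y - y_hat) ** 2).sum()        # unexplained variation
sst = ((y - y.mean()) ** 2).sum()     # total variation around the mean
r2 = 1 - sse / sst

r = np.corrcoef(x, y)[0, 1]           # Pearson correlation of feature and label
print(r2, r ** 2)
```

The identity does not carry over to models with several features or to predictions scored on held-out data, where R² can even be negative.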
'''
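Purpose: compute the predictions of locally weighted linear regression
Parameters:
    testMat: feature matrix of the test set
    xMat: feature matrix of the training set
    yMat: label matrix of the training set
    k: bandwidth of the Gaussian kernel
Returns:
    ws: regression coefficients fitted for the last test sample
    yHat: predicted values
'''
def LWLR(testMat,xMat,yMat,k=1.0):
    n = testMat.shape[0] # number of test samples
    m = xMat.shape[0] # number of training samples
    weights = np.asmatrix(np.eye(m)) # initialize the weight matrix as the identity
    yHat = np.zeros(n) # initialize the predictions with zeros
    for i in range(n):
        for j in range(m):
            diffMat = testMat[i] - xMat[j]
            # take the scalar out of the 1x1 matrix before storing it
            weights[j,j] = np.exp((diffMat*diffMat.T)[0,0] / (-2*k**2))
        xTx = xMat.T*(weights*xMat)
        if np.linalg.det(xTx) == 0:
            print('The matrix is singular and cannot be inverted')
            return
        ws = xTx.I*(xMat.T*(weights*yMat))
        yHat[i] = (testMat[i] * ws)[0,0]
    return ws,yHat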
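
The inner loop over training rows and the dense m-by-m weight matrix are likely a large part of why LWLR is slow. A minimal vectorized sketch, assuming plain ndarrays instead of matrix objects, computes all Gaussian weights for one query point at once and never builds the diagonal matrix. The name `lwlr_vectorized` and the synthetic data are assumptions for illustration; unlike LWLR it returns only the predictions.

```python
import numpy as np

def lwlr_vectorized(test_x, train_x, train_y, k=1.0):
    """Hypothetical ndarray version of LWLR that returns only the predictions."""
    test_x = np.asarray(test_x, dtype=float)
    train_x = np.asarray(train_x, dtype=float)
    train_y = np.asarray(train_y, dtype=float).ravel()
    y_hat = np.empty(test_x.shape[0])
    for i, x0 in enumerate(test_x):
        # Gaussian kernel weights: w_j = exp(-||x0 - x_j||^2 / (2 k^2))
        sq_dist = ((train_x - x0) ** 2).sum(axis=1)
        w = np.exp(-sq_dist / (2.0 * k ** 2))
        # X^T W via broadcasting, equivalent to X^T @ diag(w)
        xtw = train_x.T * w
        # Weighted normal equation (X^T W X) ws = X^T W y; raises LinAlgError if singular
        ws = np.linalg.solve(xtw @ train_x, xtw @ train_y)
        y_hat[i] = x0 @ ws
    return y_hat

# Synthetic check: with a very wide kernel every weight is close to 1,
# so the predictions should approach ordinary least squares
rng = np.random.default_rng(0)
x = rng.uniform(0, 1, size=50)
X = np.column_stack([np.ones_like(x), x])
y = 3.0 + 1.7 * x + 0.1 * np.sin(30 * x)
pred_wide = lwlr_vectorized(X, X, y, k=1e3)
ols = X @ np.linalg.lstsq(X, y, rcond=None)[0]
print(np.abs(pred_wide - ols).max())
```

The outer loop over query points remains, so the cost still grows with the product of the test and training set sizes.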

4. Build the locally weighted linear model

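Because the dataset is large and LWLR is extremely slow on it, this step fits on only the first 99 samples of the training set and evaluates on only the first 99 samples of the test set (the slices [:99] in the code).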

"""
Purpose: plot the SSE curves of the training set and the test set for different values of k
"""
def ssePlot(train,test):
    X0,Y0 = get_Mat(train)
    X1,Y1 = get_Mat(test)
    train_sse = []
    test_sse = []
    for k in np.arange(0.2,10,0.5):
        # fit on the first 99 training samples and score those same samples
        ws1,yHat1 = LWLR(X0[:99],X0[:99],Y0[:99],k)
        sse1 = ((Y0[:99].A.T - yHat1)**2).sum()
        train_sse.append(sse1)

        # same training data, scored on the first 99 test samples
        ws2,yHat2 = LWLR(X1[:99],X0[:99],Y0[:99],k)
        sse2 = ((Y1[:99].A.T - yHat2)**2).sum()
        test_sse.append(sse2)

    plt.figure(figsize=(20,8),dpi=100)
    plt.plot(np.arange(0.2,10,0.5),train_sse,color='b')
    plt.plot(np.arange(0.2,10,0.5),test_sse,color='r')
    plt.xlabel('value of k')
    plt.ylabel('SSE')
    plt.legend(['train_sse','test_sse'])

Result:

ssePlot(train,test)

[Figure output_47_0.png: training and test SSE for different values of k; image unavailable]

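The plot is best read from right to left. For large k the model is stable. As k decreases, the training SSE gradually falls, and at about k = 2 the training SSE equals the test SSE. As k decreases further, the training SSE keeps shrinking, so the model fits the training set better and better, while its performance on the test set gets worse and worse: the model is starting to overfit. This agrees with the earlier plots for k = 1.0, 0.01, and 0.003, which also showed overfitting setting in as k decreases. A value of k around 2 is therefore the best choice here.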
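
Reading the best k off the plot can also be done numerically by taking the k with the smallest test SSE. The sketch below does this on synthetic data, so the k it selects says nothing about the abalone data. The helper `lwlr_predict`, the data generator `make_data`, and the grid of k values are all assumptions for illustration.

```python
import numpy as np

def lwlr_predict(test_x, train_x, train_y, k):
    """Hypothetical compact LWLR on ndarrays, returning predictions only."""
    preds = np.empty(len(test_x))
    for i, x0 in enumerate(test_x):
        w = np.exp(-((train_x - x0) ** 2).sum(axis=1) / (2.0 * k ** 2))
        xtw = train_x.T * w
        # lstsq tolerates the near-singular systems that very small k can produce
        ws, *_ = np.linalg.lstsq(xtw @ train_x, xtw @ train_y, rcond=None)
        preds[i] = x0 @ ws
    return preds

rng = np.random.default_rng(42)

def make_data(n):
    # Made-up curve: a line plus a wiggle plus noise
    x = rng.uniform(0, 1, size=n)
    X = np.column_stack([np.ones(n), x])
    y = 3.0 + 1.7 * x + 0.1 * np.sin(30 * x) + rng.normal(0, 0.1, size=n)
    return X, y

X_tr, y_tr = make_data(100)
X_te, y_te = make_data(100)

k_grid = np.array([0.01, 0.03, 0.1, 0.3, 1.0, 3.0])
train_sse = [((y_tr - lwlr_predict(X_tr, X_tr, y_tr, k)) ** 2).sum() for k in k_grid]
test_sse = [((y_te - lwlr_predict(X_te, X_tr, y_tr, k)) ** 2).sum() for k in k_grid]

# The selected k minimizes the error on data the model was not fitted to
best_k = k_grid[int(np.argmin(test_sse))]
print(best_k, train_sse, test_sse)
```

Choosing k on the test set makes the reported test error optimistic; a separate validation split or cross-validation would avoid that.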

Next, plug k = 2 into the locally weighted linear regression model and look at the predictions:

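train,test = randSplit(aba,0.8) # randomly split the dataset into a training set and a test set
trainX,trainY = get_Mat(train) # feature matrix and label matrix of the training set
testX,testY = get_Mat(test) # feature matrix and label matrix of the test set
ws0,yHat0 = LWLR(testX,trainX,trainY,k=2)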

Plot the true values against the predicted values

y=testY.A.flatten()
plt.scatter(y,yHat0,c='b',s=5); # the trailing ; keeps Jupyter from printing the returned object

[Figure output_52_0.png: predicted values against true values for k=2; image unavailable]

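In the plot above, the horizontal axis is the true value and the vertical axis is the predicted value. The points form a funnel ("trumpet") shape: the predictions rise with the true values, but their spread widens as well, which shows that the prediction error grows as the true value increases.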
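
A funnel shape like this can be quantified by grouping the absolute error by ranges of the true value. The sketch below uses made-up predictions whose noise grows with the true value, so it only mimics the pattern and says nothing about the actual abalone results. The helper `error_by_true_value` and the bin edges are assumptions for illustration.

```python
import numpy as np
import pandas as pd

def error_by_true_value(y_true, y_pred, bins):
    """Mean absolute error of the predictions, grouped by bins of the true value."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    frame = pd.DataFrame({"true": y_true, "abs_err": np.abs(y_true - y_pred)})
    frame["bin"] = pd.cut(frame["true"], bins=bins)
    return frame.groupby("bin", observed=True)["abs_err"].mean()

rng = np.random.default_rng(0)
y_true = rng.integers(3, 25, size=2000).astype(float)
# Made-up predictions: noise standard deviation is 10% of the true value
y_pred = y_true + rng.normal(0, 0.1 * y_true)

mae = error_by_true_value(y_true, y_pred, bins=[0, 8, 16, 24])
print(mae)
```

With the real data, `y_true` and `y_pred` would be `testY.A.flatten()` and `yHat0`; a rising error from the lowest bin to the highest would confirm what the scatter plot suggests.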

Wrap the SSE and R² computation in a function so it can be reused later

"""
Purpose: compute the SSE and R² of locally weighted linear regression
"""
def LWLR_pre(dataSet):
    train,test = randSplit(dataSet,0.8)
    trainX,trainY = get_Mat(train)
    testX,testY = get_Mat(test)
    ws,yHat = LWLR(testX,trainX,trainY,k=2)
    sse = ((testY.A.T - yHat)**2).sum()
    sst = ((testY.A-testY.mean())**2).sum()
    r2 = 1 - sse / sst
    return sse,r2

Check the model's prediction results

LWLR_pre(aba)
(4152.777097646255, 0.5228101340130846)

The SSE exceeds 4,000 and R² is only 0.52, so the model does not perform very well.
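
SSE grows with the number of test rows, so it is easier to judge after converting it to a root-mean-square error (RMSE), which is in the same units as the 年齡 (age) column. The short calculation below uses only the SSE printed above and the test-set size of 835 from test.shape.

```python
import math

sse = 4152.777097646255   # SSE reported by LWLR_pre(aba)
n_test = 835              # number of test rows, from test.shape

# RMSE = sqrt(SSE / n): the typical size of one prediction error
rmse = math.sqrt(sse / n_test)
print(round(rmse, 2))     # 2.23
```

A typical error of about 2.23 against a label standard deviation of about 3.23 on the test set is consistent with an R² of roughly 0.52.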
