12. Case study: the California housing dataset (multiple linear regression) & Lasso & ridge regression & binning for nonlinear problems & polynomial regression

Case study: the California housing dataset (linear regression) & Lasso & ridge regression & binning for nonlinear problems


1. Import the required modules and libraries

from sklearn.linear_model import LinearRegression as LR
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.datasets import fetch_california_housing as fch # the California housing dataset
import pandas as pd

2. Load the data and explore it

housevalue = fch()
X = pd.DataFrame(housevalue.data)
X
0 1 2 3 4 5 6 7
0 8.3252 41.0 6.984127 1.023810 322.0 2.555556 37.88 -122.23
1 8.3014 21.0 6.238137 0.971880 2401.0 2.109842 37.86 -122.22
2 7.2574 52.0 8.288136 1.073446 496.0 2.802260 37.85 -122.24
3 5.6431 52.0 5.817352 1.073059 558.0 2.547945 37.85 -122.25
4 3.8462 52.0 6.281853 1.081081 565.0 2.181467 37.85 -122.25
... ... ... ... ... ... ... ... ...
20635 1.5603 25.0 5.045455 1.133333 845.0 2.560606 39.48 -121.09
20636 2.5568 18.0 6.114035 1.315789 356.0 3.122807 39.49 -121.21
20637 1.7000 17.0 5.205543 1.120092 1007.0 2.325635 39.43 -121.22
20638 1.8672 18.0 5.329513 1.171920 741.0 2.123209 39.43 -121.32
20639 2.3886 16.0 5.254717 1.162264 1387.0 2.616981 39.37 -121.24

20640 rows × 8 columns

X.shape
(20640, 8)
y = housevalue.target
y
array([4.526, 3.585, 3.521, ..., 0.923, 0.847, 0.894])
y.shape
(20640,)
X.head()
0 1 2 3 4 5 6 7
0 8.3252 41.0 6.984127 1.023810 322.0 2.555556 37.88 -122.23
1 8.3014 21.0 6.238137 0.971880 2401.0 2.109842 37.86 -122.22
2 7.2574 52.0 8.288136 1.073446 496.0 2.802260 37.85 -122.24
3 5.6431 52.0 5.817352 1.073059 558.0 2.547945 37.85 -122.25
4 3.8462 52.0 6.281853 1.081081 565.0 2.181467 37.85 -122.25
housevalue.feature_names
['MedInc',
 'HouseAge',
 'AveRooms',
 'AveBedrms',
 'Population',
 'AveOccup',
 'Latitude',
 'Longitude']
y.min()
0.14999
y.max()
5.00001
X.columns = housevalue.feature_names
X.head()
MedInc HouseAge AveRooms AveBedrms Population AveOccup Latitude Longitude
0 8.3252 41.0 6.984127 1.023810 322.0 2.555556 37.88 -122.23
1 8.3014 21.0 6.238137 0.971880 2401.0 2.109842 37.86 -122.22
2 7.2574 52.0 8.288136 1.073446 496.0 2.802260 37.85 -122.24
3 5.6431 52.0 5.817352 1.073059 558.0 2.547945 37.85 -122.25
4 3.8462 52.0 6.281853 1.081081 565.0 2.181467 37.85 -122.25
  • MedInc: median income of the households in the block
  • HouseAge: median age of the houses in the block
  • AveRooms: average number of rooms in the block
  • AveBedrms: average number of bedrooms in the block
  • Population: population of the block
  • AveOccup: average occupancy per household
  • Latitude: latitude of the block
  • Longitude: longitude of the block
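As an aside, the feature matrix can also be built with the column names attached in a single step. A minimal sketch (the commented-out as_frame shortcut assumes scikit-learn 0.23 or later; the target is the block's median house value in units of 100,000 USD):

import pandas as pd
from sklearn.datasets import fetch_california_housing as fch

housevalue = fch()
# attach the feature names while constructing the DataFrame
X = pd.DataFrame(housevalue.data, columns=housevalue.feature_names)
y = housevalue.target
# newer scikit-learn versions (0.23+) can also return everything as one DataFrame:
# df = fch(as_frame=True).frame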

3. Split into training and test sets

Xtrain,Xtest,Ytrain,Ytest = train_test_split(X,y,test_size=0.3,random_state=420)

# reset the indices of the feature matrices
for i in [Xtrain,Xtest]:
    i.index = range(i.shape[0])
    
Xtrain.shape
(14448, 8)

4. Build the model

reg = LR().fit(Xtrain,Ytrain) # instantiate and train the model in one step
yhat = reg.predict(Xtest)
yhat
array([1.51384887, 0.46566247, 2.2567733 , ..., 2.11885803, 1.76968187,
       0.73219077])
yhat.min()
-0.6528439725036108
yhat.max()
7.1461982142708536

5. Explore the fitted model

reg.coef_ # w, the vector of coefficients
array([ 4.37358931e-01,  1.02112683e-02, -1.07807216e-01,  6.26433828e-01,
        5.21612535e-07, -3.34850965e-03, -4.13095938e-01, -4.26210954e-01])
zip(Xtrain.columns,reg.coef_)
<zip at 0x1a1ddc21308>
[*zip(Xtrain.columns,reg.coef_)]
[('MedInc', 0.43735893059684006),
 ('HouseAge', 0.010211268294493883),
 ('AveRooms', -0.10780721617317668),
 ('AveBedrms', 0.6264338275363759),
 ('Population', 5.216125353348089e-07),
 ('AveOccup', -0.003348509646333704),
 ('Latitude', -0.4130959378947717),
 ('Longitude', -0.4262109536208464)]
reg.intercept_ # the intercept
-36.25689322920381

3 Evaluation metrics for regression models

3.1 Did the model predict the correct values?

Mean squared error (MSE) is essentially the RSS divided by the total number of samples, i.e. the average error per sample. With this average error we can compare it against the range of our label values, which gives us a fairly reliable basis for evaluation. In sklearn there are two ways to use this metric: one is the function mean_squared_error in sklearn's dedicated model-evaluation module metrics, the other is to call the cross-validation function cross_val_score and set its scoring parameter to mean squared error.

from sklearn.metrics import mean_squared_error as MSE
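# note: the conventional argument order is mean_squared_error(y_true, y_pred); squared error is symmetric, so swapping the arguments below does not change the value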
MSE(yhat,Ytest)
0.5309012639324571
Ytest.mean()
2.0819292877906976
# 10-fold cross-validation
cross_val_score(reg,X,y,cv=10,scoring="neg_mean_squared_error")
array([-0.48922052, -0.43335865, -0.8864377 , -0.39091641, -0.7479731 ,
       -0.52980278, -0.28798456, -0.77326441, -0.64305557, -0.3275106 ])
cross_val_score(reg,X,y,cv=10,scoring="neg_mean_squared_error").mean()
-0.5509524296956592
cross_val_score(reg,X,y,cv=10,scoring="neg_mean_absolute_error").mean()
-0.5445214393266326
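Note that cross_val_score returns a negative MSE here: sklearn's scorer interface assumes that higher is better, so loss-type metrics are negated. A minimal sketch (reusing reg, X and y from above) that makes the sign relationship explicit:

neg_mse = cross_val_score(reg,X,y,cv=10,scoring="neg_mean_squared_error")
mse_per_fold = -neg_mse     # flip the sign to recover the usual, positive MSE
mse_per_fold.mean()         # average MSE over the 10 folds, about 0.55 here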
# check which values the scoring parameter accepts
import sklearn
sorted(sklearn.metrics.SCORERS.keys())
['accuracy',
 'adjusted_mutual_info_score',
 'adjusted_rand_score',
 'average_precision',
 'balanced_accuracy',
 'completeness_score',
 'explained_variance',
 'f1',
 'f1_macro',
 'f1_micro',
 'f1_samples',
 'f1_weighted',
 'fowlkes_mallows_score',
 'homogeneity_score',
 'jaccard',
 'jaccard_macro',
 'jaccard_micro',
 'jaccard_samples',
 'jaccard_weighted',
 'max_error',
 'mutual_info_score',
 'neg_brier_score',
 'neg_log_loss',
 'neg_mean_absolute_error',
 'neg_mean_gamma_deviance',
 'neg_mean_poisson_deviance',
 'neg_mean_squared_error',
 'neg_mean_squared_log_error',
 'neg_median_absolute_error',
 'neg_root_mean_squared_error',
 'normalized_mutual_info_score',
 'precision',
 'precision_macro',
 'precision_micro',
 'precision_samples',
 'precision_weighted',
 'r2',
 'recall',
 'recall_macro',
 'recall_micro',
 'recall_samples',
 'recall_weighted',
 'roc_auc',
 'roc_auc_ovo',
 'roc_auc_ovo_weighted',
 'roc_auc_ovr',
 'roc_auc_ovr_weighted',
 'v_measure_score']

3.2 Did the model capture enough information?

In R², the numerator is the sum of squared differences between the true values and the predicted values (the RSS), i.e. the total amount of information our model failed to capture, and the denominator is the amount of information carried by the true labels (the total sum of squares). R² therefore measures 1 minus the proportion of the labels' information that the model failed to capture, so the closer R² is to 1, the better.
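Written out, with $y_i$ the true labels, $\hat{y}_i$ the predictions, $\bar{y}$ the mean of the true labels and $m$ the number of samples:

$$R^2 = 1 - \frac{\sum_{i=1}^{m}(y_i - \hat{y}_i)^2}{\sum_{i=1}^{m}(y_i - \bar{y})^2} = 1 - \frac{RSS}{TSS}$$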

R² can be obtained in three ways: first, import r2_score directly from metrics and score the true and predicted values; second, call the score interface of LinearRegression directly; third, pass "r2" as the scoring parameter in cross-validation.

# compute R²
from sklearn.metrics import r2_score 
# use Shift+Tab to check which argument goes first: it is r2_score(y_true, y_pred)
r2_score(Ytest,yhat)
0.6043668160178817
r2 = reg.score(Xtest,Ytest)
r2
0.6043668160178817
cross_val_score(reg,X,y,cv=10,scoring="r2").mean()
0.5110068610524557

We observe that the MSE on the California housing dataset is actually not a large number (about 0.5), yet our R² is not high. This suggests that our model fits the numerical values of part of the data reasonably well but fails to fit the distribution of the data correctly. Let's plot it to see whether that is really the case. We can draw two curves in one figure: one is the true labels Ytest, the other is our predictions yhat; the more the two curves overlap, the better the model fits.

import matplotlib.pyplot as plt 
sorted(Ytest)
[0.14999,
 0.14999,
 0.225,
 0.325,
 0.35,
 0.375,
 0.388,
 0.392,
 0.394,
 0.396,
 0.4,
 0.404,
 0.409,
 0.41,
 0.43,
 0.435,
 0.437,
 0.439,
 0.44,
 0.44,
 0.444,
 0.446,
 0.45,
 0.45,
 0.45,
 0.45,
 0.455,
 0.455,
 0.455,
 0.456,
 0.462,
 0.463,
 0.471,
 0.475,
 0.478,
 0.478,
 0.481,
 0.481,
 0.483,
 0.483,
 0.485,
 0.485,
 0.488,
 0.489,
 0.49,
 0.492,
 0.494,
 0.494,
 0.494,
 0.495,
 0.496,
 0.5,
 0.5,
 0.504,
 0.505,
 0.506,
 0.506,
 0.508,
 0.508,
 0.51,
 0.516,
 0.519,
 0.52,
 0.521,
 0.523,
 0.523,
 0.525,
 0.525,
 0.525,
 0.525,
 0.525,
 0.527,
 0.527,
 0.528,
 0.529,
 0.53,
 0.531,
 0.532,
 0.534,
 0.535,
 0.535,
 0.535,
 0.538,
 0.538,
 0.539,
 0.539,
 0.539,
 0.541,
 0.541,
 0.542,
 0.542,
 0.542,
 0.543,
 0.543,
 0.544,
 0.544,
 0.546,
 0.547,
 0.55,
 0.55,
 0.55,
 0.55,
 0.55,
 0.55,
 0.55,
 0.55,
 0.551,
 0.553,
 0.553,
 0.553,
 0.554,
 0.554,
 0.554,
 0.555,
 0.556,
 0.556,
 0.557,
 0.558,
 0.558,
 0.559,
 0.559,
 0.559,
 0.559,
 0.56,
 0.56,
 0.562,
 0.566,
 0.567,
 0.567,
 0.567,
 0.567,
 0.567,
 0.568,
 0.57,
 0.571,
 0.572,
 0.574,
 0.574,
 0.575,
 0.575,
 0.575,
 0.575,
 0.576,
 0.577,
 0.577,
 0.577,
 0.578,
 0.579,
 0.579,
 0.579,
 0.58,
 0.58,
 0.58,
 0.58,
 0.58,
 0.58,
 0.581,
 0.581,
 0.581,
 0.581,
 0.582,
 0.583,
 0.583,
 0.583,
 0.583,
 0.584,
 0.586,
 0.586,
 0.587,
 0.588,
 0.588,
 0.59,
 0.59,
 0.59,
 0.59,
 0.591,
 0.591,
 0.593,
 0.593,
 0.594,
 0.594,
 0.594,
 0.594,
 0.595,
 0.596,
 0.596,
 0.597,
 0.598,
 0.598,
 0.6,
 0.6,
 0.6,
 0.602,
 0.602,
 0.603,
 0.604,
 0.604,
 0.604,
 0.605,
 0.606,
 0.606,
 0.608,
 0.608,
 0.608,
 0.609,
 0.609,
 0.611,
 0.612,
 0.612,
 0.613,
 0.613,
 0.613,
 0.614,
 0.615,
 0.616,
 0.616,
 0.616,
 0.616,
 0.618,
 0.618,
 0.618,
 0.619,
 0.619,
 0.62,
 0.62,
 0.62,
 0.62,
 0.62,
 0.62,
 0.62,
 0.62,
 0.621,
 0.621,
 0.621,
 0.622,
 0.623,
 0.625,
 0.625,
 0.625,
 0.627,
 0.627,
 0.628,
 0.628,
 0.629,
 0.63,
 0.63,
 0.63,
 0.63,
 0.631,
 0.631,
 0.632,
 0.632,
 0.633,
 0.633,
 0.633,
 0.634,
 0.634,
 0.635,
 0.635,
 0.635,
 0.635,
 0.635,
 0.637,
 0.637,
 0.637,
 0.637,
 0.638,
 0.639,
 0.643,
 0.644,
 0.644,
 0.646,
 0.646,
 0.646,
 0.646,
 0.647,
 0.647,
 0.647,
 0.648,
 0.65,
 0.65,
 0.65,
 0.652,
 0.652,
 0.654,
 0.654,
 0.654,
 0.655,
 0.656,
 0.656,
 0.656,
 0.656,
 0.657,
 0.658,
 0.658,
 0.659,
 0.659,
 0.659,
 0.659,
 0.659,
 0.66,
 0.661,
 0.661,
 0.662,
 0.662,
 0.663,
 0.664,
 0.664,
 0.664,
 0.668,
 0.669,
 0.669,
 0.67,
 0.67,
 0.67,
 0.67,
 0.67,
 0.67,
 0.672,
 0.672,
 0.672,
 0.673,
 0.673,
 0.674,
 0.675,
 0.675,
 0.675,
 0.675,
 0.675,
 0.675,
 0.675,
 0.675,
 0.675,
 0.675,
 0.675,
 0.675,
 0.675,
 0.676,
 0.676,
 0.677,
 0.678,
 0.68,
 0.68,
 0.681,
 0.682,
 0.682,
 0.682,
 0.682,
 0.683,
 0.683,
 0.683,
 0.684,
 0.684,
 0.685,
 0.685,
 0.685,
 0.685,
 0.686,
 0.686,
 0.687,
 0.688,
 0.689,
 0.689,
 0.689,
 0.69,
 0.69,
 0.691,
 0.691,
 0.692,
 0.693,
 0.694,
 0.694,
 0.694,
 0.694,
 0.694,
 0.695,
 0.695,
 0.695,
 0.696,
 0.696,
 0.697,
 0.698,
 0.699,
 0.699,
 0.7,
 0.7,
 0.7,
 0.7,
 0.7,
 0.7,
 0.701,
 0.701,
 0.701,
 0.702,
 0.702,
 0.703,
 0.704,
 0.704,
 0.705,
 0.705,
 0.706,
 0.707,
 0.707,
 0.707,
 0.708,
 0.709,
 0.71,
 0.71,
 0.71,
 0.711,
 0.712,
 0.712,
 0.713,
 0.713,
 0.713,
 0.714,
 0.715,
 0.716,
 0.718,
 0.719,
 0.72,
 0.72,
 0.72,
 0.721,
 0.722,
 0.723,
 0.723,
 0.723,
 0.723,
 0.723,
 0.725,
 0.725,
 0.727,
 0.727,
 0.728,
 0.729,
 0.729,
 0.73,
 0.73,
 0.73,
 0.73,
 0.73,
 0.731,
 0.731,
 0.731,
 0.731,
 0.732,
 0.733,
 0.733,
 0.734,
 0.735,
 0.735,
 0.737,
 0.738,
 0.738,
 0.738,
 0.74,
 0.74,
 0.74,
 0.741,
 0.741,
 0.741,
 0.743,
 0.746,
 0.746,
 0.747,
 0.748,
 0.749,
 0.75,
 0.75,
 0.75,
 0.75,
 0.75,
 0.75,
 0.75,
 0.752,
 0.752,
 0.754,
 0.756,
 0.756,
 0.757,
 0.759,
 0.759,
 0.759,
 0.76,
 0.76,
 0.761,
 0.762,
 0.762,
 0.762,
 0.762,
 0.763,
 0.764,
 0.764,
 0.765,
 0.766,
 0.768,
 0.769,
 0.77,
 0.771,
 0.771,
 0.771,
 0.772,
 0.772,
 0.773,
 0.774,
 0.774,
 0.775,
 0.777,
 0.777,
 0.779,
 0.78,
 0.78,
 0.78,
 0.781,
 0.783,
 0.783,
 0.785,
 0.786,
 0.786,
 0.786,
 0.786,
 0.788,
 0.788,
 0.788,
 0.788,
 0.788,
 0.79,
 0.79,
 0.79,
 0.792,
 0.792,
 0.792,
 0.795,
 0.795,
 0.795,
 0.797,
 0.797,
 0.798,
 0.799,
 0.8,
 0.801,
 0.802,
 0.803,
 0.804,
 0.804,
 0.804,
 0.806,
 0.806,
 0.808,
 0.808,
 0.808,
 0.809,
 0.81,
 0.81,
 0.811,
 0.813,
 0.813,
 0.813,
 0.813,
 0.813,
 0.813,
 0.813,
 0.813,
 0.813,
 0.813,
 0.813,
 0.813,
 0.813,
 0.813,
 0.814,
 0.814,
 0.816,
 0.817,
 0.817,
 0.817,
 0.821,
 0.821,
 0.821,
 0.823,
 0.823,
 0.824,
 0.825,
 0.825,
 0.825,
 0.826,
 0.827,
 0.827,
 0.828,
 0.828,
 0.828,
 0.83,
 0.83,
 0.83,
 0.831,
 0.831,
 0.831,
 0.832,
 0.832,
 0.832,
 0.833,
 0.833,
 0.834,
 0.835,
 0.835,
 0.836,
 0.836,
 0.837,
 0.838,
 0.839,
 0.839,
 0.839,
 0.839,
 0.84,
 0.841,
 0.842,
 0.842,
 0.842,
 0.843,
 0.843,
 0.844,
 0.844,
 0.844,
 0.845,
 0.845,
 0.845,
 0.845,
 0.846,
 0.846,
 0.846,
 0.846,
 0.847,
 0.847,
 0.847,
 0.847,
 0.847,
 0.847,
 0.848,
 0.849,
 0.849,
 0.85,
 0.85,
 0.85,
 0.851,
 0.851,
 0.851,
 0.851,
 0.852,
 0.853,
 0.853,
 0.854,
 0.854,
 0.854,
 0.855,
 0.855,
 0.855,
 0.855,
 0.856,
 0.857,
 0.857,
 0.857,
 0.857,
 0.857,
 0.858,
 0.859,
 0.859,
 0.859,
 0.859,
 0.859,
 0.861,
 0.862,
 0.863,
 0.863,
 0.863,
 0.864,
 0.864,
 0.864,
 0.864,
 0.865,
 0.865,
 0.865,
 0.866,
 0.867,
 0.867,
 0.868,
 0.869,
 0.869,
 0.869,
 0.869,
 0.87,
 0.87,
 0.871,
 0.871,
 0.872,
 0.872,
 0.872,
 0.873,
 0.874,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.875,
 0.876,
 0.876,
 0.877,
 0.877,
 0.878,
 0.878,
 0.878,
 0.879,
 0.879,
 0.879,
 0.88,
 0.88,
 0.881,
 0.881,
 0.882,
 0.882,
 0.882,
 0.882,
 0.883,
 0.883,
 0.883,
 0.883,
 0.883,
 0.883,
 0.884,
 0.885,
 0.885,
 0.886,
 0.887,
 0.887,
 0.887,
 0.888,
 0.888,
 0.888,
 0.889,
 0.889,
 0.889,
 0.889,
 0.889,
 0.89,
 0.891,
 0.892,
 0.892,
 0.892,
 0.893,
 0.893,
 0.894,
 0.895,
 0.896,
 0.896,
 0.897,
 0.897,
 0.898,
 0.898,
 0.899,
 0.9,
 0.9,
 0.9,
 0.901,
 0.901,
 0.901,
 0.902,
 0.903,
 0.903,
 0.904,
 0.904,
 0.904,
 0.905,
 0.905,
 0.905,
 0.905,
 0.906,
 0.906,
 0.906,
 0.906,
 0.907,
 0.907,
 0.908,
 0.911,
 0.911,
 0.912,
 0.914,
 0.915,
 0.915,
 0.916,
 0.916,
 0.917,
 0.917,
 0.917,
 0.917,
 0.918,
 0.918,
 0.918,
 0.919,
 0.919,
 0.919,
 0.92,
 0.92,
 0.922,
 0.922,
 0.922,
 0.922,
 0.922,
 0.924,
 0.925,
 0.925,
 0.925,
 0.925,
 0.926,
 0.926,
 0.926,
 0.926,
 0.926,
 0.926,
 0.926,
 0.926,
 0.926,
 0.926,
 0.927,
 0.927,
 0.927,
 0.927,
 0.928,
 0.928,
 0.928,
 0.928,
 0.928,
 0.929,
 0.93,
 0.93,
 0.931,
 0.931,
 0.931,
 0.931,
 0.931,
 0.931,
 0.932,
 0.932,
 0.932,
 0.932,
 0.933,
 0.933,
 0.933,
 0.934,
 0.934,
 0.934,
 0.934,
 0.934,
 0.935,
 0.935,
 0.935,
 0.936,
 0.936,
 0.936,
 0.936,
 0.938,
 0.938,
 0.938,
 0.938,
 0.938,
 0.938,
 0.938,
 0.938,
 0.938,
 0.938,
 0.938,
 0.939,
 0.939,
 0.94,
 0.94,
 0.942,
 0.942,
 0.943,
 0.943,
 0.944,
 0.944,
 0.944,
 0.945,
 0.945,
 0.946,
 0.946,
 0.946,
 0.946,
 0.946,
 0.946,
 0.946,
 0.947,
 0.947,
 0.948,
 0.948,
 0.948,
 0.949,
 0.949,
 0.95,
 0.95,
 0.95,
 0.95,
 0.95,
 0.951,
 0.952,
 0.952,
 0.953,
 0.953,
 0.953,
 0.953,
 0.954,
 0.955,
 0.955,
 0.955,
 0.955,
 0.955,
 0.956,
 0.957,
 0.957,
 0.957,
 0.958,
 0.958,
 0.958,
 0.958,
 0.958,
 0.958,
 0.96,
 0.96,
 0.96,
 0.96,
 0.96,
 0.96,
 0.961,
 0.961,
 0.962,
 0.962,
 0.962,
 0.962,
 0.962,
 0.962,
 0.962,
 0.963,
 0.964,
 0.964,
 0.964,
 0.964,
 0.965,
 0.965,
 0.965,
 0.966,
 0.966,
 0.966,
 0.967,
 0.967,
 0.967,
 0.968,
 0.968,
 0.969,
 0.969,
 0.969,
 0.969,
 0.97,
 0.971,
 0.972,
 0.972,
 0.973,
 0.973,
 0.973,
 0.974,
 0.974,
 0.974,
 0.974,
 0.976,
 0.976,
 0.976,
 0.976,
 0.977,
 0.977,
 0.978,
 0.978,
 0.978,
 0.979,
 0.979,
 ...]
plt.plot(range(len(Ytest)),sorted(Ytest),c='black',label='Data')
plt.plot(range(len(yhat)),sorted(yhat),c='red',label='Predict')
plt.legend()
plt.show()

[Figure: output_44_0.png — sorted true values Ytest (black, 'Data') versus sorted predictions yhat (red, 'Predict')]

As we can see, most of the data is fitted fairly well, but there are fairly large fitting errors at the beginning and the end of the curve. If more of the data were distributed towards the right of the plot, the model would drift further and further away from the true labels. This is the situation we described earlier: the model predicts the values correctly on a limited dataset, but it has not correctly fitted the distribution of the data, so if more data were fed to the model, the labels would very likely be predicted wrongly.

When R² turns out to be negative, it means the model fits the data extremely badly and cannot be used at all. So yes, a negative R² is a legitimate outcome. Of course, in practice, if your linear regression produces a negative R², that does not mean you simply have to accept it: first check whether your modelling procedure and data preprocessing are correct; perhaps you have damaged the data itself, or there is a bug in your modelling pipeline. If you are using an ensemble regressor, check whether you have enough weak estimators: random forests and boosted trees easily produce a negative R² when they contain only two or three trees. If you have checked all your code, confirmed that the preprocessing is fine, and R² is still negative, that proves linear regression is not suited to your data, and you should try other algorithms.
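As a quick illustration (a toy example, not part of the original notebook), R² drops below zero as soon as the predictions do worse than simply predicting the label mean:

import numpy as np
from sklearn.metrics import r2_score

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred_mean = np.full_like(y_true, y_true.mean())   # always predict the mean -> R² is exactly 0
y_pred_bad = np.array([4.0, 3.0, 2.0, 1.0])         # worse than predicting the mean -> negative R²
r2_score(y_true, y_pred_mean)   # 0.0
r2_score(y_true, y_pred_bad)    # -3.0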

4.2.1 Ridge regression for multicollinearity

Compared with linear regression, ridge regression has a few more parameters, but the only truly essential one is the regularization coefficient α (alpha). The other parameters are only needed when we want to solve ridge regression with something other than least squares, and normally we never touch them, so all you need to understand is how to use alpha.
Earlier, using linear regression on the California housing dataset, we obtained roughly 60% fit on the training set and about 60% on the test set as well. Is this rather low fit caused by multicollinearity? In statistics we would use the VIF or various tests to judge whether the data is collinear; in machine learning, however, we can let the model decide: if a dataset shows no clear improvement in ridge regression across different values of the regularization parameter (for example the performance stays flat or drops), the data has no multicollinearity, at most some correlation between features. Conversely, if the dataset shows a clear upward trend across the various regularization values, the data does suffer from multicollinearity.
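For the statistical VIF check mentioned above, a minimal sketch (this relies on statsmodels, which the original notebook does not use; variance_inflation_factor expects a constant column to be added first):

import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor

# X is the California housing feature DataFrame from above
X_const = sm.add_constant(X)                      # VIF needs an intercept column
vif = [variance_inflation_factor(X_const.values, i)
       for i in range(1, X_const.shape[1])]       # skip the constant itself
dict(zip(X.columns, vif))                         # VIF > 10 is a common rule of thumb for collinearity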

Next, let's verify this claim on the California housing dataset:

import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge, LinearRegression, Lasso 
from sklearn.model_selection import train_test_split as TTS 
from sklearn.datasets import fetch_california_housing as fch 
import matplotlib.pyplot as plt
housevalue = fch()
X = pd.DataFrame(housevalue.data)
y = housevalue.target
X.columns = ["住戶收入中位數","房屋使用年代中位數","平均房間數目","平均臥室數目","街區人口","平均入住率","街區的緯度","街區的經度"]
X.head()
住戶收入中位數 房屋使用年代中位數 平均房間數目 平均臥室數目 街區人口 平均入住率 街區的緯度 街區的經度
0 8.3252 41.0 6.984127 1.023810 322.0 2.555556 37.88 -122.23
1 8.3014 21.0 6.238137 0.971880 2401.0 2.109842 37.86 -122.22
2 7.2574 52.0 8.288136 1.073446 496.0 2.802260 37.85 -122.24
3 5.6431 52.0 5.817352 1.073059 558.0 2.547945 37.85 -122.25
4 3.8462 52.0 6.281853 1.081081 565.0 2.181467 37.85 -122.25
Xtrain,Xtest,Ytrain,Ytest = TTS(X,y,test_size=0.3,random_state=420)
# reset the dataset indices
for i in [Xtrain,Xtest]:
    i.index = range(i.shape[0])
# model the data with ridge regression
reg = Ridge(alpha=1).fit(Xtrain,Ytrain)
reg.score(Xtest,Ytest)
0.6043610352312276
# under cross-validation, how does ridge regression compare with linear regression?
alpharange = np.arange(1,1001,100)
ridge, lr = [], []
for alpha in alpharange:
    reg = Ridge(alpha=alpha)
    linear = LinearRegression()
    regs = cross_val_score(reg,X,y,cv=5,scoring = "r2").mean()#    
    linears = cross_val_score(linear,X,y,cv=5,scoring = "r2").mean()#     
    ridge.append(regs)
    lr.append(linears)
    
plt.plot(alpharange,ridge,color="red",label="Ridge") 
plt.plot(alpharange,lr,color="orange",label="LR") 
plt.title("Mean")
plt.legend()
plt.show()

[Figure: output_55_0.png — mean cross-validated R² for Ridge vs. LinearRegression, alpha from 1 to 1001]

# refine the learning curve
# under cross-validation, how does ridge regression compare with linear regression?
alpharange = np.arange(1,201,10)
ridge, lr = [], []
for alpha in alpharange:
    reg = Ridge(alpha=alpha)
    linear = LinearRegression()
    regs = cross_val_score(reg,X,y,cv=5,scoring = "r2").mean()#    
    linears = cross_val_score(linear,X,y,cv=5,scoring = "r2").mean()#     
    ridge.append(regs)
    lr.append(linears)
    
plt.plot(alpharange,ridge,color="red",label="Ridge") 
plt.plot(alpharange,lr,color="orange",label="LR") 
plt.title("Mean")
plt.legend()
plt.show()

[Figure: output_56_0.png — mean cross-validated R² for Ridge vs. LinearRegression, alpha from 1 to 201]

As we can see, on the California dataset the ridge regression result rises slightly and then drops sharply. We can say that the California housing dataset carries a very slight amount of collinearity; once that collinearity is removed by the regularization parameter α, the model improves a little, but for the model as a whole it is a drop in the bucket. Past the point where multicollinearity is under control, the model's performance falls rapidly: the regularization is now clearly too heavy and is squeezing the estimation space that the parameters w actually need. This result shows that the core problem of the California dataset is not multicollinearity, and ridge regression cannot improve the model's performance here.

Also, as the regularization parameter gradually increases, we can watch how the variance of the model changes:

# how does the model's variance change?
alpharange = np.arange(1,1001,100) 
ridge, lr = [], []
for alpha in alpharange:
    reg = Ridge(alpha=alpha)
    linear = LinearRegression()
    varR = cross_val_score(reg,X,y,cv=5,scoring="r2").var()#  
    varLR = cross_val_score(linear,X,y,cv=5,scoring="r2").var()#     
    ridge.append(varR)
    lr.append(varLR)
    
plt.plot(alpharange,ridge,color="red",label="Ridge") 
plt.plot(alpharange,lr,color="orange",label="LR") 
plt.title("Variance")
plt.legend()
plt.show()

[Figure: output_59_0.png — variance of the cross-validated R² for Ridge vs. LinearRegression on the California data]

We find that the variance of the model rises quickly, but the variance values themselves are very small: the change is no more than a third of the gain in R², so as long as the noise conditions stay constant, the model's generalization error may still have dropped to some extent. Although ridge regression and Lasso are not designed to improve performance but to focus on solving multicollinearity, when α varies within a certain range, removing multicollinearity may improve the model's ability to generalize to some degree.
However, there is no direct metric for generalization ability, so we can usually only judge roughly whether it has improved by looking at the model's accuracy metric together with its variance. Let's look at a case where multicollinearity is more pronounced:

from sklearn.datasets import load_boston
from sklearn.model_selection import cross_val_score
X = load_boston().data
y = load_boston().target
Xtrain,Xtest,Ytrain,Ytest = TTS(X,y,test_size=0.3,random_state=420)
# first look at how the variance changes
alpharange = np.arange(1,1001,100)
ridge, lr = [], []
for alpha in alpharange:
    reg = Ridge(alpha=alpha)
    linear = LinearRegression()
    varR = cross_val_score(reg,X,y,cv=5,scoring="r2").var()#     
    varLR = cross_val_score(linear,X,y,cv=5,scoring="r2").var()#     
    ridge.append(varR)
    lr.append(varLR)
    
plt.plot(alpharange,ridge,color="red",label="Ridge") 
plt.plot(alpharange,lr,color="orange",label="LR") 
plt.title("Variance")
plt.legend()
plt.show()    

[Figure: output_62_0.png — variance of the cross-validated R² for Ridge vs. LinearRegression on the Boston housing data]

# look at how R² changes
alpharange = np.arange(1,1001,100)
ridge, lr = [], []
for alpha in alpharange:
    reg = Ridge(alpha=alpha)
    linear = LinearRegression()
    regs = cross_val_score(reg,X,y,cv=5,scoring = "r2").mean()#     
    linears = cross_val_score(linear,X,y,cv=5,scoring = "r2").mean()#     
    ridge.append(regs)
    lr.append(linears)
    
plt.plot(alpharange,ridge,color="red",label="Ridge") 
plt.plot(alpharange,lr,color="orange",label="LR") 
plt.title("Mean")
plt.legend()
plt.show()

[Figure: output_63_0.png — mean cross-validated R² for Ridge vs. LinearRegression on the Boston housing data]

# refine the learning curve
alpharange = np.arange(100,300,10)
ridge, lr = [], []
for alpha in alpharange:
    reg = Ridge(alpha=alpha)
    #linear = LinearRegression()
    regs = cross_val_score(reg,X,y,cv=5,scoring = "r2").mean()
    #linears = cross_val_score(linear,X,y,cv=5,scoring = "r2").mean()     
    ridge.append(regs)
    #lr.append(linears)
plt.plot(alpharange,ridge,color="red",label="Ridge") 
#plt.plot(alpharange,lr,color="orange",label="LR") 
plt.title("Mean")
plt.legend() 
plt.show()

[Figure: output_64_0.png — mean cross-validated R² for Ridge on the Boston data, alpha from 100 to 300]

We find that, compared with the California housing dataset, on the Boston housing dataset the variance drops clearly and so does the bias, so ridge regression does have an effect here and the model's generalization ability may well improve.
Unfortunately, nobody wants the data they collect to contain multicollinearity, so the datasets published on scikit-learn or Kaggle have generally already been treated for it; it is very hard to find data with truly severe multicollinearity, which makes it impossible to show ridge regression really shining on real data. We may find data with some correlation between features, but if you experiment you will find that ridge regression or Lasso usually lowers the model's performance and rarely raises it, which is probably part of why ridge regression and Lasso have been somewhat neglected in machine learning.

4.2.3 Choosing the best value of the regularization parameter

The ridge trace plot puts the regularization parameter on the horizontal axis and the coefficients w solved by the linear model on the vertical axis, where each coloured line is one coefficient w. Its goal is to establish a direct relationship between the regularization parameter and the coefficients w, so that we can observe how changing the regularization parameter affects the fitted coefficients. The ridge trace reading is: the more the lines cross, the stronger the multicollinearity between features, and we should pick the α at the point where the coefficient curves settle down (the "mouth" of the trace) as the best value of the regularization parameter. Drawing the ridge trace plot is very simple; the code is as follows:

import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model
# create a 10x10 Hilbert matrix
X = 1. / (np.arange(1, 11) + np.arange(0, 10)[:, np.newaxis]) 
y = np.ones(10)
X
array([[1.        , 0.5       , 0.33333333, 0.25      , 0.2       ,
        0.16666667, 0.14285714, 0.125     , 0.11111111, 0.1       ],
       [0.5       , 0.33333333, 0.25      , 0.2       , 0.16666667,
        0.14285714, 0.125     , 0.11111111, 0.1       , 0.09090909],
       [0.33333333, 0.25      , 0.2       , 0.16666667, 0.14285714,
        0.125     , 0.11111111, 0.1       , 0.09090909, 0.08333333],
       [0.25      , 0.2       , 0.16666667, 0.14285714, 0.125     ,
        0.11111111, 0.1       , 0.09090909, 0.08333333, 0.07692308],
       [0.2       , 0.16666667, 0.14285714, 0.125     , 0.11111111,
        0.1       , 0.09090909, 0.08333333, 0.07692308, 0.07142857],
       [0.16666667, 0.14285714, 0.125     , 0.11111111, 0.1       ,
        0.09090909, 0.08333333, 0.07692308, 0.07142857, 0.06666667],
       [0.14285714, 0.125     , 0.11111111, 0.1       , 0.09090909,
        0.08333333, 0.07692308, 0.07142857, 0.06666667, 0.0625    ],
       [0.125     , 0.11111111, 0.1       , 0.09090909, 0.08333333,
        0.07692308, 0.07142857, 0.06666667, 0.0625    , 0.05882353],
       [0.11111111, 0.1       , 0.09090909, 0.08333333, 0.07692308,
        0.07142857, 0.06666667, 0.0625    , 0.05882353, 0.05555556],
       [0.1       , 0.09090909, 0.08333333, 0.07692308, 0.07142857,
        0.06666667, 0.0625    , 0.05882353, 0.05555556, 0.05263158]])
# compute the horizontal-axis values
n_alphas = 200
alphas = np.logspace(-10, -2, n_alphas)
alphas
array([1.00000000e-10, 1.09698580e-10, 1.20337784e-10, 1.32008840e-10,
       1.44811823e-10, 1.58856513e-10, 1.74263339e-10, 1.91164408e-10,
       2.09704640e-10, 2.30043012e-10, 2.52353917e-10, 2.76828663e-10,
       3.03677112e-10, 3.33129479e-10, 3.65438307e-10, 4.00880633e-10,
       4.39760361e-10, 4.82410870e-10, 5.29197874e-10, 5.80522552e-10,
       6.36824994e-10, 6.98587975e-10, 7.66341087e-10, 8.40665289e-10,
       9.22197882e-10, 1.01163798e-09, 1.10975250e-09, 1.21738273e-09,
       1.33545156e-09, 1.46497140e-09, 1.60705282e-09, 1.76291412e-09,
       1.93389175e-09, 2.12145178e-09, 2.32720248e-09, 2.55290807e-09,
       2.80050389e-09, 3.07211300e-09, 3.37006433e-09, 3.69691271e-09,
       4.05546074e-09, 4.44878283e-09, 4.88025158e-09, 5.35356668e-09,
       5.87278661e-09, 6.44236351e-09, 7.06718127e-09, 7.75259749e-09,
       8.50448934e-09, 9.32930403e-09, 1.02341140e-08, 1.12266777e-08,
       1.23155060e-08, 1.35099352e-08, 1.48202071e-08, 1.62575567e-08,
       1.78343088e-08, 1.95639834e-08, 2.14614120e-08, 2.35428641e-08,
       2.58261876e-08, 2.83309610e-08, 3.10786619e-08, 3.40928507e-08,
       3.73993730e-08, 4.10265811e-08, 4.50055768e-08, 4.93704785e-08,
       5.41587138e-08, 5.94113398e-08, 6.51733960e-08, 7.14942899e-08,
       7.84282206e-08, 8.60346442e-08, 9.43787828e-08, 1.03532184e-07,
       1.13573336e-07, 1.24588336e-07, 1.36671636e-07, 1.49926843e-07,
       1.64467618e-07, 1.80418641e-07, 1.97916687e-07, 2.17111795e-07,
       2.38168555e-07, 2.61267523e-07, 2.86606762e-07, 3.14403547e-07,
       3.44896226e-07, 3.78346262e-07, 4.15040476e-07, 4.55293507e-07,
       4.99450512e-07, 5.47890118e-07, 6.01027678e-07, 6.59318827e-07,
       7.23263390e-07, 7.93409667e-07, 8.70359136e-07, 9.54771611e-07,
       1.04737090e-06, 1.14895100e-06, 1.26038293e-06, 1.38262217e-06,
       1.51671689e-06, 1.66381689e-06, 1.82518349e-06, 2.00220037e-06,
       2.19638537e-06, 2.40940356e-06, 2.64308149e-06, 2.89942285e-06,
       3.18062569e-06, 3.48910121e-06, 3.82749448e-06, 4.19870708e-06,
       4.60592204e-06, 5.05263107e-06, 5.54266452e-06, 6.08022426e-06,
       6.66991966e-06, 7.31680714e-06, 8.02643352e-06, 8.80488358e-06,
       9.65883224e-06, 1.05956018e-05, 1.16232247e-05, 1.27505124e-05,
       1.39871310e-05, 1.53436841e-05, 1.68318035e-05, 1.84642494e-05,
       2.02550194e-05, 2.22194686e-05, 2.43744415e-05, 2.67384162e-05,
       2.93316628e-05, 3.21764175e-05, 3.52970730e-05, 3.87203878e-05,
       4.24757155e-05, 4.65952567e-05, 5.11143348e-05, 5.60716994e-05,
       6.15098579e-05, 6.74754405e-05, 7.40196000e-05, 8.11984499e-05,
       8.90735464e-05, 9.77124154e-05, 1.07189132e-04, 1.17584955e-04,
       1.28989026e-04, 1.41499130e-04, 1.55222536e-04, 1.70276917e-04,
       1.86791360e-04, 2.04907469e-04, 2.24780583e-04, 2.46581108e-04,
       2.70495973e-04, 2.96730241e-04, 3.25508860e-04, 3.57078596e-04,
       3.91710149e-04, 4.29700470e-04, 4.71375313e-04, 5.17092024e-04,
       5.67242607e-04, 6.22257084e-04, 6.82607183e-04, 7.48810386e-04,
       8.21434358e-04, 9.01101825e-04, 9.88495905e-04, 1.08436597e-03,
       1.18953407e-03, 1.30490198e-03, 1.43145894e-03, 1.57029012e-03,
       1.72258597e-03, 1.88965234e-03, 2.07292178e-03, 2.27396575e-03,
       2.49450814e-03, 2.73644000e-03, 3.00183581e-03, 3.29297126e-03,
       3.61234270e-03, 3.96268864e-03, 4.34701316e-03, 4.76861170e-03,
       5.23109931e-03, 5.73844165e-03, 6.29498899e-03, 6.90551352e-03,
       7.57525026e-03, 8.30994195e-03, 9.11588830e-03, 1.00000000e-02])
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['simhei'] # make Chinese characters display correctly in plots
plt.rcParams['axes.unicode_minus']=False # display minus signs correctly
%matplotlib inline
# fit a model for every regularization value and collect the coefficients
coefs = []
for a in alphas:
    ridge = linear_model.Ridge(alpha=a, fit_intercept=False)#     
    ridge.fit(X, y)
    coefs.append(ridge.coef_)
    
#plot the results
ax = plt.gca()# plt.plot() actually obtains the current Axes object via plt.gca() and then calls ax.plot() to do the real drawing
ax.plot(alphas, coefs)
ax.set_xscale('log')
ax.set_xlim(ax.get_xlim()[::-1])# reverse the x-axis
plt.xlabel('正則化參數alpha')
plt.ylabel('係數w')
plt.title('嶺迴歸下的嶺跡圖')
plt.axis('tight')
plt.show()

[Figure: output_72_1.png — ridge trace plot: coefficients w versus the regularization parameter alpha (log scale, axis reversed)]

import numpy as np
import pandas as pd
from sklearn.linear_model import RidgeCV, LinearRegression 
from sklearn.model_selection import train_test_split as TTS 
from sklearn.datasets import fetch_california_housing as fch 
import matplotlib.pyplot as plt

housevalue = fch()
X = pd.DataFrame(housevalue.data)
y = housevalue.target
X.columns = ["住戶收入中位數","房屋使用年代中位數","平均房間數目","平均臥室數目","街區人口","平均入住率","街區的緯度","街區的經度"]
Ridge_ = RidgeCV(alphas=np.arange(1,1001,100),store_cv_values=True).fit(X, y)
# the ridge regression score irrespective of cross-validation (the refitted model scored on the full data)
Ridge_.score(X,y)
0.6060251767338429
# access all the cross-validation results
Ridge_.cv_values_.shape
(20640, 10)
# averaging gives the cross-validation result for each value of the regularization coefficient
Ridge_.cv_values_.mean(axis=0)
array([0.52823795, 0.52787439, 0.52807763, 0.52855759, 0.52917958,
       0.52987689, 0.53061486, 0.53137481, 0.53214638, 0.53292369])
# the best regularization coefficient that was selected
Ridge_.alpha_
101
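As a sanity check (a small sketch reusing Ridge_ from above): cv_values_ stores the leave-one-out squared error for every sample under every candidate alpha, and alpha_ is simply the candidate whose mean error is smallest:

import numpy as np
alphas = np.arange(1,1001,100)
mean_loo_error = Ridge_.cv_values_.mean(axis=0)   # one mean leave-one-out error per candidate alpha
alphas[mean_loo_error.argmin()]                   # 101, the same value as Ridge_.alpha_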

4.3.2 The core use of Lasso: feature selection

import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge, LinearRegression, Lasso 
from sklearn.model_selection import train_test_split as TTS
from sklearn.datasets import fetch_california_housing as fch 
import matplotlib.pyplot as plt
housevalue = fch()
X = pd.DataFrame(housevalue.data)
y = housevalue.target
X.columns = ["住戶收入中位數","房屋使用年代中位數","平均房間數目","平均臥室數目","街區人口","平均入住率","街區的緯度","街區的經度"]
X.head()
住戶收入中位數 房屋使用年代中位數 平均房間數目 平均臥室數目 街區人口 平均入住率 街區的緯度 街區的經度
0 8.3252 41.0 6.984127 1.023810 322.0 2.555556 37.88 -122.23
1 8.3014 21.0 6.238137 0.971880 2401.0 2.109842 37.86 -122.22
2 7.2574 52.0 8.288136 1.073446 496.0 2.802260 37.85 -122.24
3 5.6431 52.0 5.817352 1.073059 558.0 2.547945 37.85 -122.25
4 3.8462 52.0 6.281853 1.081081 565.0 2.181467 37.85 -122.25
Xtrain,Xtest,Ytrain,Ytest = TTS(X,y,test_size=0.3,random_state=420)
# reset the indices
for i in [Xtrain,Xtest]:
    i.index = range(i.shape[0])
# fit with linear regression
reg = LinearRegression().fit(Xtrain,Ytrain)
(reg.coef_*100).tolist()
[43.735893059684,
 1.0211268294493883,
 -10.780721617317667,
 62.64338275363759,
 5.216125353348089e-05,
 -0.3348509646333704,
 -41.30959378947717,
 -42.62109536208464]
reg.coef_*100
array([ 4.37358931e+01,  1.02112683e+00, -1.07807216e+01,  6.26433828e+01,
        5.21612535e-05, -3.34850965e-01, -4.13095938e+01, -4.26210954e+01])
# fit with ridge regression
Ridge_ = Ridge(alpha=0).fit(Xtrain,Ytrain)
(Ridge_.coef_*100).tolist()
[43.735893059684024,
 1.0211268294494151,
 -10.780721617317592,
 62.64338275363727,
 5.2161253532709486e-05,
 -0.3348509646333586,
 -41.30959378947672,
 -42.62109536208427]
# fit with Lasso
lasso_ = Lasso(alpha=0).fit(Xtrain,Ytrain)
(lasso_.coef_*100).tolist()
D:\ProgramData\Anaconda3\lib\site-packages\ipykernel_launcher.py:2: UserWarning: With alpha=0, this algorithm does not converge well. You are advised to use the LinearRegression estimator
  
D:\ProgramData\Anaconda3\lib\site-packages\sklearn\linear_model\_coordinate_descent.py:476: UserWarning: Coordinate descent with no regularization may lead to unexpected results and is discouraged.
  positive)
D:\ProgramData\Anaconda3\lib\site-packages\sklearn\linear_model\_coordinate_descent.py:476: ConvergenceWarning: Objective did not converge. You might want to increase the number of iterations. Duality gap: 3769.8607714139175, tolerance: 1.9172554769131482
  positive)

[43.73589305968398,
 1.0211268294494045,
 -10.780721617317642,
 62.64338275363768,
 5.2161253532676174e-05,
 -0.33485096463335784,
 -41.30959378947721,
 -42.62109536208479]

As we can see, ridge regression raised no errors, but Lasso is a different story: although it still computed the coefficients, it threw a full three red warning blocks:

The three warnings say, respectively:

    1. With a regularization coefficient of 0 this algorithm cannot converge well; if you want the regularization coefficient to be 0, use LinearRegression instead.
    2. Coordinate descent without a regularization term may lead to unexpected results and is discouraged.
    3. The objective did not converge; you may want to increase the number of iterations, and fitting the model with a very small alpha may cause precision problems.

With coordinate descent come questions of iteration and convergence, which is why sklearn does not recommend a regularization coefficient of 0. If we really do want a value close to 0, we can use a fairly small number instead, such as 0.01 or a similarly tiny value:

# fit with ridge regression
Ridge_ = Ridge(alpha=0.01).fit(Xtrain,Ytrain)
(Ridge_.coef_*100).tolist()
[43.735757206215965,
 1.0211292318121794,
 -10.780460336251618,
 62.64202320775656,
 5.217068073242219e-05,
 -0.3348506517067619,
 -41.309571432291364,
 -42.62105388932401]
# fit with Lasso
lasso_ = Lasso(alpha=0.01).fit(Xtrain,Ytrain)
(lasso_.coef_*100).tolist()
[40.105683718344864,
 1.0936292607860143,
 -3.7423763610244585,
 26.524037834897218,
 0.00035253685115039596,
 -0.3207129394887797,
 -40.064830473448424,
 -40.81754399163315]
# increase the regularization coefficient and watch how the model's coefficients change
Ridge_ = Ridge(alpha=10**4).fit(Xtrain,Ytrain)
(Ridge_.coef_*100).tolist()
[34.62081517607694,
 1.5196170869238694,
 0.3968610529210159,
 0.9151812510354821,
 0.002173923801224847,
 -0.3476866014810102,
 -14.736963474215234,
 -13.435576102526895]
lasso_ = Lasso(alpha=10**4).fit(Xtrain,Ytrain)
(lasso_.coef_*100).tolist()
[0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0]
# apparently 10**4 is far too large a value for Lasso
lasso_ = Lasso(alpha=1).fit(Xtrain,Ytrain)
(lasso_.coef_*100).tolist()
[14.581141247629418,
 0.6209347344423873,
 0.0,
 -0.0,
 -0.0002806598632901005,
 -0.0,
 -0.0,
 -0.0]
# plot the coefficients
plt.plot(range(1,9),(reg.coef_*100).tolist(),color="red",label="LR")
plt.plot(range(1,9),(Ridge_.coef_*100).tolist(),color="orange",label="Ridge") 
plt.plot(range(1,9),(lasso_.coef_*100).tolist(),color="k",label="Lasso") 
plt.plot(range(1,9),[0]*8,color="grey",linestyle="--")
plt.xlabel('w') # the horizontal axis indexes the coefficient of each feature
plt.legend()
plt.show()

[Figure: output_97_0.png — coefficients of LR, Ridge (alpha=10**4) and Lasso (alpha=1) for the eight features]

Clearly, compared with ridge regression, the L1 penalty in Lasso punishes the coefficients much more heavily, and it compresses coefficients all the way to 0, which is why it can be used for feature selection. For the same reason, we usually let Lasso's regularization coefficient α vary within a very small range when searching for the best value.
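To make the feature-selection effect explicit, a small sketch (reusing the lasso_ just fitted with alpha=1, and Xtrain with its Chinese column names) that lists which features keep a nonzero coefficient:

# pair each column with its Lasso coefficient; exact zeros mark features that were dropped
selected = [(name, coef) for name, coef in zip(Xtrain.columns, lasso_.coef_) if coef != 0]
selected   # with alpha=1 only the income and house-age columns keep sizeable weights, plus a negligible population term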

4.3.3 Choosing the best value of the regularization parameter

from sklearn.linear_model import LassoCV
# build our own range of alpha values for Lasso to choose from
alpharange = np.logspace(-10, -2, 200,base=10)
# this is an exponential grid with base 10
# from 10**(-10) to 10**(-2)
alpharange.shape
(200,)
Xtrain.head()
住戶收入中位數 房屋使用年代中位數 平均房間數目 平均臥室數目 街區人口 平均入住率 街區的緯度 街區的經度
0 4.1776 35.0 4.425172 1.030683 5380.0 3.368817 37.48 -122.19
1 5.3261 38.0 6.267516 1.089172 429.0 2.732484 37.53 -122.30
2 1.9439 26.0 5.768977 1.141914 891.0 2.940594 36.02 -119.08
3 2.5000 22.0 4.916000 1.012000 733.0 2.932000 38.57 -121.31
4 3.8250 34.0 5.036765 1.098039 1134.0 2.779412 33.91 -118.35
lasso_ = LassoCV(alphas=alpharange # the alpha range we supply ourselves
                 ,cv=5 # number of cross-validation folds
                ).fit(Xtrain, Ytrain)
# the best regularization coefficient that was selected
lasso_.alpha_
0.0020729217795953697
# access all the cross-validation results
lasso_.mse_path_
array([[0.52454913, 0.49856261, 0.55984312, 0.50526576, 0.55262557],
       [0.52361933, 0.49748809, 0.55887637, 0.50429373, 0.55283734],
       [0.52281927, 0.49655113, 0.55803797, 0.5034594 , 0.55320522],
       [0.52213811, 0.49574741, 0.55731858, 0.50274517, 0.55367515],
       [0.52155715, 0.49505688, 0.55669995, 0.50213252, 0.55421553],
       [0.52106069, 0.49446226, 0.55616707, 0.50160604, 0.55480104],
       [0.5206358 , 0.49394903, 0.55570702, 0.50115266, 0.55541214],
       [0.52027135, 0.49350539, 0.55530895, 0.50076146, 0.55603333],
       [0.51995825, 0.49312085, 0.5549639 , 0.50042318, 0.55665306],
       [0.5196886 , 0.49278705, 0.55466406, 0.50013007, 0.55726225],
       [0.51945602, 0.49249647, 0.55440306, 0.49987554, 0.55785451],
       [0.51925489, 0.49224316, 0.55417527, 0.49965404, 0.55842496],
       [0.51908068, 0.49202169, 0.55397615, 0.49946088, 0.55897049],
       [0.51892938, 0.49182782, 0.55380162, 0.49929206, 0.55948886],
       [0.51879778, 0.49165759, 0.55364841, 0.49914421, 0.55997905],
       [0.51868299, 0.49150788, 0.55351357, 0.49901446, 0.5604405 ],
       [0.51858268, 0.49137604, 0.55339469, 0.49890035, 0.56087323],
       [0.51849488, 0.49125956, 0.55328972, 0.4987998 , 0.56127784],
       [0.5184178 , 0.49115652, 0.55319678, 0.49871101, 0.56165507],
       [0.51835002, 0.49106526, 0.55311438, 0.49863248, 0.5620059 ],
       [0.51829033, 0.49098418, 0.55304118, 0.49856287, 0.56233145],
       [0.51823761, 0.49091208, 0.55297609, 0.49850108, 0.56263308],
       [0.51819098, 0.49084785, 0.55291806, 0.49844612, 0.56291204],
       [0.51814966, 0.49079058, 0.55286626, 0.49839716, 0.56316966],
       [0.51811298, 0.49073937, 0.55281996, 0.49835348, 0.56340721],
       [0.51808038, 0.49069355, 0.55277854, 0.49831445, 0.5636261 ],
       [0.51805132, 0.49065249, 0.5527414 , 0.49827953, 0.56382754],
       [0.5180254 , 0.49061566, 0.55270806, 0.49824828, 0.56401276],
       [0.51800224, 0.49058258, 0.55267812, 0.49822015, 0.56418292],
       [0.51798152, 0.49055285, 0.55265118, 0.49819493, 0.56433912],
       [0.51796296, 0.49052608, 0.55262693, 0.49817225, 0.56448243],
       [0.5179463 , 0.49050195, 0.55260507, 0.49815185, 0.56461379],
       [0.51793135, 0.49048019, 0.55258536, 0.49813345, 0.5647342 ],
       [0.51791791, 0.49046055, 0.55256757, 0.49811687, 0.56484448],
       [0.5179058 , 0.49044281, 0.55255149, 0.4981019 , 0.56494544],
       [0.5178949 , 0.49042677, 0.55253695, 0.49808838, 0.56503784],
       [0.51788506, 0.49041226, 0.55252379, 0.49807615, 0.56512236],
       [0.51787619, 0.49039913, 0.55251189, 0.4980651 , 0.56519967],
       [0.51786817, 0.49038724, 0.5525011 , 0.49805509, 0.56527034],
       [0.51786092, 0.49037646, 0.55249132, 0.49804603, 0.56533494],
       [0.51785437, 0.49036669, 0.55248246, 0.49803782, 0.56539397],
       [0.51784843, 0.49035783, 0.55247442, 0.49803037, 0.5654479 ],
       [0.51784306, 0.49034979, 0.55246712, 0.49802362, 0.56549716],
       [0.51783819, 0.49034249, 0.5524605 , 0.49801749, 0.56554215],
       [0.51783377, 0.49033586, 0.55245448, 0.49801193, 0.56558322],
       [0.51782977, 0.49032984, 0.55244901, 0.49800688, 0.56562073],
       [0.51782614, 0.49032437, 0.55244405, 0.49800229, 0.56565496],
       [0.51782284, 0.49031939, 0.55243953, 0.49799812, 0.56568621],
       [0.51781984, 0.49031487, 0.55243543, 0.49799434, 0.56571472],
       [0.51781712, 0.49031076, 0.55243169, 0.49799089, 0.56574074],
       [0.51781465, 0.49030702, 0.5524283 , 0.49798776, 0.56576449],
       [0.5178124 , 0.49030362, 0.55242521, 0.49798491, 0.56578615],
       [0.51781036, 0.49030052, 0.5524224 , 0.49798232, 0.56580591],
       [0.5178085 , 0.4902977 , 0.55241984, 0.49797996, 0.56582394],
       [0.51780681, 0.49029514, 0.55241751, 0.49797781, 0.56584039],
       [0.51780528, 0.4902928 , 0.55241539, 0.49797586, 0.56585539],
       [0.51780388, 0.49029068, 0.55241346, 0.49797408, 0.56586907],
       [0.51780261, 0.49028874, 0.55241171, 0.49797246, 0.56588155],
       [0.51780145, 0.49028698, 0.55241011, 0.49797099, 0.56589293],
       [0.51780039, 0.49028538, 0.55240865, 0.49796965, 0.56590331],
       [0.51779943, 0.49028392, 0.55240732, 0.49796843, 0.56591277],
       [0.51779856, 0.49028258, 0.55240611, 0.49796731, 0.5659214 ],
       [0.51779777, 0.49028137, 0.55240501, 0.4979663 , 0.56592927],
       [0.51779704, 0.49028027, 0.55240401, 0.49796538, 0.56593645],
       [0.51779638, 0.49027926, 0.5524031 , 0.49796454, 0.56594299],
       [0.51779578, 0.49027834, 0.55240226, 0.49796377, 0.56594896],
       [0.51779523, 0.49027751, 0.55240151, 0.49796307, 0.5659544 ],
       [0.51779473, 0.49027675, 0.55240081, 0.49796243, 0.56595936],
       [0.51779428, 0.49027605, 0.55240018, 0.49796185, 0.56596388],
       [0.51779386, 0.49027542, 0.55239961, 0.49796133, 0.565968  ],
       [0.51779349, 0.49027485, 0.55239909, 0.49796085, 0.56597176],
       [0.51779314, 0.49027432, 0.55239861, 0.49796041, 0.56597519],
       [0.51779283, 0.49027384, 0.55239818, 0.49796001, 0.56597831],
       [0.51779254, 0.49027341, 0.55239778, 0.49795964, 0.56598116],
       [0.51779228, 0.49027301, 0.55239742, 0.49795931, 0.56598376],
       [0.51779205, 0.49027265, 0.55239709, 0.49795901, 0.56598613],
       [0.51779183, 0.49027232, 0.55239679, 0.49795873, 0.56598828],
       [0.51779163, 0.49027202, 0.55239652, 0.49795848, 0.56599025],
       [0.51779146, 0.49027174, 0.55239627, 0.49795825, 0.56599205],
       [0.51779129, 0.49027149, 0.55239604, 0.49795804, 0.56599368],
       [0.51779114, 0.49027127, 0.55239584, 0.49795785, 0.56599517],
       [0.51779101, 0.49027106, 0.55239565, 0.49795768, 0.56599653],
       [0.51779088, 0.49027087, 0.55239548, 0.49795752, 0.56599777],
       [0.51779077, 0.4902707 , 0.55239532, 0.49795738, 0.5659989 ],
       [0.51779067, 0.49027054, 0.55239518, 0.49795725, 0.56599993],
       [0.51779057, 0.4902704 , 0.55239505, 0.49795713, 0.56600087],
       [0.51779049, 0.49027027, 0.55239493, 0.49795702, 0.56600172],
       [0.51779041, 0.49027015, 0.55239482, 0.49795692, 0.5660025 ],
       [0.51779034, 0.49027004, 0.55239472, 0.49795683, 0.56600322],
       [0.51779027, 0.49026994, 0.55239463, 0.49795675, 0.56600386],
       [0.51779022, 0.49026985, 0.55239455, 0.49795667, 0.56600446],
       [0.51779016, 0.49026977, 0.55239448, 0.4979566 , 0.56600499],
       [0.51779011, 0.49026969, 0.55239441, 0.49795654, 0.56600549],
       [0.51779007, 0.49026962, 0.55239435, 0.49795648, 0.56600593],
       [0.51779003, 0.49026956, 0.55239429, 0.49795643, 0.56600634],
       [0.51778999, 0.49026951, 0.55239424, 0.49795638, 0.56600671],
       [0.51778996, 0.49026945, 0.55239419, 0.49795634, 0.56600705],
       [0.51778993, 0.49026941, 0.55239415, 0.4979563 , 0.56600736],
       [0.5177899 , 0.49026936, 0.55239411, 0.49795626, 0.56600764],
       [0.51778987, 0.49026932, 0.55239407, 0.49795623, 0.5660079 ],
       [0.51778985, 0.49026929, 0.55239404, 0.4979562 , 0.56600813],
       [0.51778983, 0.49026926, 0.55239401, 0.49795617, 0.56600835],
       [0.51778981, 0.49026923, 0.55239398, 0.49795615, 0.56600854],
       [0.51778979, 0.4902692 , 0.55239396, 0.49795613, 0.56600872],
       [0.51778977, 0.49026918, 0.55239394, 0.49795611, 0.56600888],
       [0.51778976, 0.49026915, 0.55239392, 0.49795609, 0.56600903],
       [0.51778975, 0.49026913, 0.5523939 , 0.49795607, 0.56600916],
       [0.51778973, 0.49026911, 0.55239388, 0.49795605, 0.56600929],
       [0.51778972, 0.4902691 , 0.55239387, 0.49795604, 0.5660094 ],
       [0.51778971, 0.49026908, 0.55239385, 0.49795603, 0.5660095 ],
       [0.5177897 , 0.49026907, 0.55239384, 0.49795602, 0.56600959],
       [0.5177897 , 0.49026905, 0.55239383, 0.49795601, 0.56600968],
       [0.51778969, 0.49026904, 0.55239382, 0.497956  , 0.56600975],
       [0.51778968, 0.49026903, 0.55239381, 0.49795599, 0.56600983],
       [0.51778967, 0.49026902, 0.5523938 , 0.49795598, 0.56600989],
       [0.51778967, 0.49026901, 0.55239379, 0.49795597, 0.56600995],
       [0.51778966, 0.490269  , 0.55239378, 0.49795596, 0.56601   ],
       [0.51778966, 0.490269  , 0.55239378, 0.49795596, 0.56601005],
       [0.51778965, 0.49026899, 0.55239377, 0.49795595, 0.56601009],
       [0.51778965, 0.49026898, 0.55239376, 0.49795595, 0.56601013],
       [0.51778965, 0.49026898, 0.55239376, 0.49795594, 0.56601017],
       [0.51778964, 0.49026897, 0.55239375, 0.49795594, 0.5660102 ],
       [0.51778964, 0.49026897, 0.55239375, 0.49795593, 0.56601023],
       [0.51778964, 0.49026896, 0.55239375, 0.49795593, 0.56601026],
       [0.51778963, 0.49026896, 0.55239374, 0.49795593, 0.56601029],
       [0.51778963, 0.49026896, 0.55239374, 0.49795592, 0.56601031],
       [0.51778963, 0.49026895, 0.55239374, 0.49795592, 0.56601033],
       [0.51778963, 0.49026895, 0.55239373, 0.49795592, 0.56601035],
       [0.51778963, 0.49026895, 0.55239373, 0.49795592, 0.56601037],
       [0.51778962, 0.49026895, 0.55239373, 0.49795591, 0.56601039],
       [0.51778962, 0.49026894, 0.55239373, 0.49795591, 0.5660104 ],
       [0.51778962, 0.49026894, 0.55239372, 0.49795591, 0.56601041],
       [0.51778962, 0.49026894, 0.55239372, 0.49795591, 0.56601043],
       [0.51778962, 0.49026894, 0.55239372, 0.49795591, 0.56601044],
       [0.51778962, 0.49026894, 0.55239372, 0.49795591, 0.56601045],
       [0.51778962, 0.49026894, 0.55239372, 0.49795591, 0.56601046],
       [0.51778962, 0.49026893, 0.55239372, 0.4979559 , 0.56601046],
       [0.51778962, 0.49026893, 0.55239372, 0.4979559 , 0.56601047],
       [0.51778962, 0.49026893, 0.55239372, 0.4979559 , 0.56601048],
       [0.51778961, 0.49026893, 0.55239371, 0.4979559 , 0.56601048],
       [0.51778961, 0.49026893, 0.55239371, 0.4979559 , 0.56601049],
       [0.51778961, 0.49026893, 0.55239371, 0.4979559 , 0.5660105 ],
       [0.51778961, 0.49026893, 0.55239371, 0.4979559 , 0.5660105 ],
       [0.51778961, 0.49026893, 0.55239371, 0.4979559 , 0.5660105 ],
       [0.51778961, 0.49026893, 0.55239371, 0.4979559 , 0.56601051],
       [0.51778961, 0.49026893, 0.55239371, 0.4979559 , 0.56601051],
       [0.51778961, 0.49026893, 0.55239371, 0.4979559 , 0.56601052],
       [0.51778961, 0.49026893, 0.55239371, 0.4979559 , 0.56601052],
       [0.51778961, 0.49026893, 0.55239371, 0.4979559 , 0.56601052],
       [0.51778961, 0.49026892, 0.55239371, 0.4979559 , 0.56601052],
       [0.51778961, 0.49026892, 0.55239371, 0.4979559 , 0.56601053],
       [0.51778961, 0.49026892, 0.55239371, 0.4979559 , 0.56601053],
       [0.51778961, 0.49026892, 0.55239371, 0.4979559 , 0.56601053],
       [0.51778961, 0.49026892, 0.55239371, 0.4979559 , 0.56601053],
       [0.51778961, 0.49026892, 0.55239371, 0.4979559 , 0.56601053],
       [0.51778961, 0.49026892, 0.55239371, 0.4979559 , 0.56601054],
       [0.51778961, 0.49026892, 0.55239371, 0.4979559 , 0.56601054],
       [0.51778961, 0.49026892, 0.55239371, 0.4979559 , 0.56601054],
       [0.51778961, 0.49026892, 0.55239371, 0.4979559 , 0.56601054],
       [0.51778961, 0.49026892, 0.55239371, 0.4979559 , 0.56601054],
       [0.51778961, 0.49026892, 0.55239371, 0.4979559 , 0.56601054],
       [0.51778961, 0.49026892, 0.55239371, 0.4979559 , 0.56601054],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601054],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601054],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601054],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601054],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601054],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055],
       [0.51778961, 0.49026892, 0.55239371, 0.49795589, 0.56601055]])
lasso_.mse_path_.shape # the five-fold cross-validation results under each alpha
(200, 5)
lasso_.mse_path_.mean(axis=1) # notice that in ridge regression we averaged over axis=0?
#in RidgeCV we used leave-one-out validation, so the cross-validation results were returned per sample under each alpha
#to average over each alpha we therefore took axis=0, i.e. down the rows
#here, in contrast, the results are returned per fold under each alpha value
#so to average over each alpha we take axis=1, i.e. across the columns
array([0.52816924, 0.52742297, 0.5268146 , 0.52632488, 0.52593241,
       0.52561942, 0.52537133, 0.5251761 , 0.52502385, 0.52490641,
       0.52481712, 0.52475046, 0.52470198, 0.52466795, 0.52464541,
       0.52463188, 0.5246254 , 0.52462436, 0.52462744, 0.52463361,
       0.52464201, 0.52465199, 0.52466301, 0.52467466, 0.5246866 ,
       0.5246986 , 0.52471046, 0.52472203, 0.5247332 , 0.52474392,
       0.52475413, 0.52476379, 0.52477291, 0.52478147, 0.52478949,
       0.52479697, 0.52480393, 0.52481039, 0.52481639, 0.52482193,
       0.52482706, 0.52483179, 0.52483615, 0.52484016, 0.52484385,
       0.52484725, 0.52485036, 0.52485322, 0.52485584, 0.52485824,
       0.52486044, 0.52486246, 0.5248643 , 0.52486599, 0.52486753,
       0.52486895, 0.52487024, 0.52487141, 0.52487249, 0.52487348,
       0.52487437, 0.52487519, 0.52487594, 0.52487663, 0.52487725,
       0.52487782, 0.52487834, 0.52487882, 0.52487925, 0.52487965,
       0.52488001, 0.52488033, 0.52488063, 0.52488091, 0.52488116,
       0.52488138, 0.52488159, 0.52488178, 0.52488195, 0.52488211,
       0.52488225, 0.52488239, 0.5248825 , 0.52488261, 0.52488271,
       0.5248828 , 0.52488289, 0.52488296, 0.52488303, 0.52488309,
       0.52488315, 0.5248832 , 0.52488325, 0.52488329, 0.52488333,
       0.52488337, 0.5248834 , 0.52488343, 0.52488346, 0.52488348,
       0.5248835 , 0.52488352, 0.52488354, 0.52488356, 0.52488357,
       0.52488359, 0.5248836 , 0.52488361, 0.52488362, 0.52488363,
       0.52488364, 0.52488365, 0.52488366, 0.52488367, 0.52488367,
       0.52488368, 0.52488368, 0.52488369, 0.52488369, 0.5248837 ,
       0.5248837 , 0.5248837 , 0.52488371, 0.52488371, 0.52488371,
       0.52488371, 0.52488371, 0.52488372, 0.52488372, 0.52488372,
       0.52488372, 0.52488372, 0.52488372, 0.52488372, 0.52488373,
       0.52488373, 0.52488373, 0.52488373, 0.52488373, 0.52488373,
       0.52488373, 0.52488373, 0.52488373, 0.52488373, 0.52488373,
       0.52488373, 0.52488373, 0.52488373, 0.52488373, 0.52488373,
       0.52488373, 0.52488373, 0.52488373, 0.52488373, 0.52488373,
       0.52488373, 0.52488373, 0.52488373, 0.52488373, 0.52488373,
       0.52488373, 0.52488373, 0.52488374, 0.52488374, 0.52488374,
       0.52488374, 0.52488374, 0.52488374, 0.52488374, 0.52488374,
       0.52488374, 0.52488374, 0.52488374, 0.52488374, 0.52488374,
       0.52488374, 0.52488374, 0.52488374, 0.52488374, 0.52488374,
       0.52488374, 0.52488374, 0.52488374, 0.52488374, 0.52488374,
       0.52488374, 0.52488374, 0.52488374, 0.52488374, 0.52488374,
       0.52488374, 0.52488374, 0.52488374, 0.52488374, 0.52488374,
       0.52488374, 0.52488374, 0.52488374, 0.52488374, 0.52488374])
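As with RidgeCV, the chosen alpha_ is simply the candidate with the smallest mean cross-validated MSE. A small sketch (note that LassoCV stores the alpha grid it actually used, sorted in descending order, in alphas_):

mean_mse = lasso_.mse_path_.mean(axis=1)     # one mean MSE per candidate alpha
lasso_.alphas_[mean_mse.argmin()]            # equals lasso_.alpha_, about 0.00207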
# the model coefficients obtained under the best regularization coefficient
lasso_.coef_
array([ 4.29867301e-01,  1.03623683e-02, -9.32648616e-02,  5.51755252e-01,
        1.14732262e-06, -3.31941716e-03, -4.10451223e-01, -4.22410330e-01])
lasso_.score(Xtest,Ytest)
0.6038982670571438
# how does this compare with linear regression?
reg = LinearRegression().fit(Xtrain,Ytrain)
reg.score(Xtest,Ytest)
0.6043668160178817
# use LassoCV's own regularization-path length (eps) and number of alphas (n_alphas) to build the alpha range automatically
ls_ = LassoCV(eps=0.00001,n_alphas=300,cv=5).fit(Xtrain, Ytrain)
ls_.alpha_
0.0020954551690628535
ls_.alphas_ # look at all the automatically generated alpha values
array([2.94059737e+01, 2.82952253e+01, 2.72264331e+01, 2.61980122e+01,
       2.52084378e+01, 2.42562424e+01, 2.33400142e+01, 2.24583946e+01,
       2.16100763e+01, 2.07938014e+01, 2.00083596e+01, 1.92525862e+01,
       1.85253605e+01, 1.78256042e+01, 1.71522798e+01, 1.65043887e+01,
       1.58809704e+01, 1.52811004e+01, 1.47038891e+01, 1.41484809e+01,
       1.36140520e+01, 1.30998100e+01, 1.26049924e+01, 1.21288655e+01,
       1.16707233e+01, 1.12298864e+01, 1.08057012e+01, 1.03975388e+01,
       1.00047937e+01, 9.62688384e+00, 9.26324869e+00, 8.91334908e+00,
       8.57666619e+00, 8.25270079e+00, 7.94097249e+00, 7.64101907e+00,
       7.35239575e+00, 7.07467457e+00, 6.80744372e+00, 6.55030695e+00,
       6.30288297e+00, 6.06480491e+00, 5.83571975e+00, 5.61528779e+00,
       5.40318218e+00, 5.19908842e+00, 5.00270386e+00, 4.81373731e+00,
       4.63190858e+00, 4.45694804e+00, 4.28859627e+00, 4.12660362e+00,
       3.97072991e+00, 3.82074399e+00, 3.67642348e+00, 3.53755437e+00,
       3.40393074e+00, 3.27535446e+00, 3.15163488e+00, 3.03258855e+00,
       2.91803894e+00, 2.80781620e+00, 2.70175688e+00, 2.59970374e+00,
       2.50150543e+00, 2.40701636e+00, 2.31609642e+00, 2.22861078e+00,
       2.14442973e+00, 2.06342843e+00, 1.98548679e+00, 1.91048923e+00,
       1.83832455e+00, 1.76888573e+00, 1.70206982e+00, 1.63777773e+00,
       1.57591415e+00, 1.51638733e+00, 1.45910901e+00, 1.40399425e+00,
       1.35096134e+00, 1.29993164e+00, 1.25082947e+00, 1.20358204e+00,
       1.15811928e+00, 1.11437377e+00, 1.07228066e+00, 1.03177753e+00,
       9.92804320e-01, 9.55303239e-01, 9.19218682e-01, 8.84497142e-01,
       8.51087135e-01, 8.18939121e-01, 7.88005430e-01, 7.58240193e-01,
       7.29599275e-01, 7.02040207e-01, 6.75522125e-01, 6.50005707e-01,
       6.25453118e-01, 6.01827951e-01, 5.79095174e-01, 5.57221080e-01,
       5.36173234e-01, 5.15920425e-01, 4.96432623e-01, 4.77680932e-01,
       4.59637546e-01, 4.42275711e-01, 4.25569683e-01, 4.09494689e-01,
       3.94026894e-01, 3.79143363e-01, 3.64822025e-01, 3.51041645e-01,
       3.37781790e-01, 3.25022798e-01, 3.12745750e-01, 3.00932442e-01,
       2.89565356e-01, 2.78627638e-01, 2.68103069e-01, 2.57976043e-01,
       2.48231544e-01, 2.38855123e-01, 2.29832877e-01, 2.21151426e-01,
       2.12797900e-01, 2.04759910e-01, 1.97025538e-01, 1.89583315e-01,
       1.82422207e-01, 1.75531594e-01, 1.68901260e-01, 1.62521372e-01,
       1.56382472e-01, 1.50475455e-01, 1.44791563e-01, 1.39322368e-01,
       1.34059761e-01, 1.28995937e-01, 1.24123389e-01, 1.19434891e-01,
       1.14923491e-01, 1.10582499e-01, 1.06405479e-01, 1.02386238e-01,
       9.85188143e-02, 9.47974747e-02, 9.12167008e-02, 8.77711831e-02,
       8.44558125e-02, 8.12656730e-02, 7.81960343e-02, 7.52423447e-02,
       7.24002244e-02, 6.96654592e-02, 6.70339940e-02, 6.45019268e-02,
       6.20655031e-02, 5.97211101e-02, 5.74652717e-02, 5.52946427e-02,
       5.32060046e-02, 5.11962605e-02, 4.92624301e-02, 4.74016461e-02,
       4.56111493e-02, 4.38882847e-02, 4.22304977e-02, 4.06353301e-02,
       3.91004165e-02, 3.76234811e-02, 3.62023337e-02, 3.48348672e-02,
       3.35190539e-02, 3.22529426e-02, 3.10346560e-02, 2.98623876e-02,
       2.87343991e-02, 2.76490180e-02, 2.66046349e-02, 2.55997012e-02,
       2.46327267e-02, 2.37022776e-02, 2.28069742e-02, 2.19454891e-02,
       2.11165447e-02, 2.03189119e-02, 1.95514080e-02, 1.88128950e-02,
       1.81022777e-02, 1.74185025e-02, 1.67605555e-02, 1.61274610e-02,
       1.55182803e-02, 1.49321101e-02, 1.43680812e-02, 1.38253574e-02,
       1.33031338e-02, 1.28006361e-02, 1.23171192e-02, 1.18518661e-02,
       1.14041869e-02, 1.09734179e-02, 1.05589203e-02, 1.01600794e-02,
       9.77630394e-03, 9.40702475e-03, 9.05169431e-03, 8.70978573e-03,
       8.38079201e-03, 8.06422534e-03, 7.75961630e-03, 7.46651323e-03,
       7.18448150e-03, 6.91310292e-03, 6.65197510e-03, 6.40071082e-03,
       6.15893752e-03, 5.92629670e-03, 5.70244339e-03, 5.48704566e-03,
       5.27978413e-03, 5.08035147e-03, 4.88845195e-03, 4.70380102e-03,
       4.52612490e-03, 4.35516012e-03, 4.19065316e-03, 4.03236011e-03,
       3.88004625e-03, 3.73348572e-03, 3.59246120e-03, 3.45676358e-03,
       3.32619166e-03, 3.20055181e-03, 3.07965774e-03, 2.96333019e-03,
       2.85139667e-03, 2.74369120e-03, 2.64005407e-03, 2.54033162e-03,
       2.44437597e-03, 2.35204484e-03, 2.26320133e-03, 2.17771369e-03,
       2.09545517e-03, 2.01630379e-03, 1.94014218e-03, 1.86685742e-03,
       1.79634083e-03, 1.72848786e-03, 1.66319789e-03, 1.60037411e-03,
       1.53992337e-03, 1.48175602e-03, 1.42578583e-03, 1.37192979e-03,
       1.32010804e-03, 1.27024376e-03, 1.22226299e-03, 1.17609459e-03,
       1.13167011e-03, 1.08892367e-03, 1.04779188e-03, 1.00821376e-03,
       9.70130622e-04, 9.33485992e-04, 8.98225535e-04, 8.64296967e-04,
       8.31649980e-04, 8.00236162e-04, 7.70008936e-04, 7.40923479e-04,
       7.12936663e-04, 6.86006990e-04, 6.60094529e-04, 6.35160855e-04,
       6.11168999e-04, 5.88083384e-04, 5.65869780e-04, 5.44495247e-04,
       5.23928092e-04, 5.04137817e-04, 4.85095079e-04, 4.66771639e-04,
       4.49140329e-04, 4.32175004e-04, 4.15850508e-04, 4.00142636e-04,
       3.85028095e-04, 3.70484474e-04, 3.56490207e-04, 3.43024545e-04,
       3.30067519e-04, 3.17599917e-04, 3.05603253e-04, 2.94059737e-04])
ls_.alphas_.shape
(300,)
ls_.score(Xtest,Ytest)
0.6038915423819199
ls_.coef_
array([ 4.29785372e-01,  1.03639989e-02, -9.31060823e-02,  5.50940621e-01,
        1.15407943e-06, -3.31909776e-03, -4.10423420e-01, -4.22369926e-01])

How does linear regression behave on nonlinear data? Let's build an obviously nonlinear dataset and observe how linear regression and decision-tree regression behave when fitting it:

  1. Import the required libraries
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
  2. Create the dataset to fit
rnd = np.random.RandomState(42) # set the random seed
X = rnd.uniform(-3, 3, size=100) # rnd.uniform draws size random numbers uniformly between the two given bounds
X
array([-0.75275929,  2.70428584,  1.39196365,  0.59195091, -2.06388816,
       -2.06403288, -2.65149833,  2.19705687,  0.60669007,  1.24843547,
       -2.87649303,  2.81945911,  1.99465584, -1.72596534, -1.9090502 ,
       -1.89957294, -1.17454654,  0.14853859, -0.40832989, -1.25262516,
        0.67111737, -2.16303684, -1.24713211, -0.80182894, -0.26358009,
        1.71105577, -1.80195731,  0.08540663,  0.55448741, -2.72129752,
        0.64526911, -1.97685526, -2.60969044,  2.69331322,  2.7937922 ,
        1.85038409, -1.17231738, -2.41396732,  1.10539816, -0.35908504,
       -2.26777059, -0.02893854, -2.79366887,  2.45592241, -1.44732011,
        0.97513371, -1.12973354,  0.12040813,  0.28026168, -1.89087327,
        2.81750777,  1.65079694,  2.63699365,  2.3689641 ,  0.58739987,
        2.53124541, -2.46904499, -1.82410283, -2.72863627, -1.04801802,
       -0.66793626, -1.37190581,  1.97242505, -0.85948004, -1.31439294,
        0.2561765 , -2.15445465,  1.81318188, -2.55269614,  2.92132162,
        1.63346862, -1.80770591, -2.9668673 ,  1.89276857,  1.24114406,
        1.37404301,  1.62762208, -2.55573209, -0.84920563, -2.30478564,
        2.17862056,  0.73978876, -1.01461185, -2.6186499 , -1.13410607,
       -1.04890007,  1.37763707,  0.82534483,  2.32327646, -0.16671045,
       -2.28243452,  1.27946872,  1.56471029,  0.36766319,  1.62580308,
       -0.03722642,  0.13639698, -0.43475389, -2.84748524, -2.35265144])
#生成y的思路:先使用NumPy中的函數生成一個sin函數圖像,然後再人爲添加噪音
y = np.sin(X) + rnd.normal(size=len(X)) / 3 #random.normal,生成size個服從正態分佈的隨機數
y
array([-6.54639413e-01,  3.23832143e-01,  1.01463893e+00, -1.04541922e-01,
       -9.54097511e-01, -7.61767511e-01,  2.19222347e-02,  6.37468193e-01,
        3.00653482e-01,  7.81237778e-01,  4.31286305e-02,  4.26174779e-01,
        7.34921650e-01, -8.16896281e-01, -9.10976357e-01, -6.23556402e-01,
       -1.15653261e+00,  3.87722577e-02, -5.27779796e-01, -1.43764744e+00,
        7.20568176e-01, -7.42673638e-01, -9.46371922e-01, -7.96824836e-01,
       -7.32328912e-01,  8.49964652e-01, -1.08763923e+00, -1.82122918e-01,
        4.72745679e-01, -2.73346296e-01,  1.23014210e+00, -8.60492039e-01,
       -4.21323537e-01,  4.08600304e-01, -2.98759618e-01,  9.52331322e-01,
       -9.01575520e-01,  1.55982485e-01,  8.29522626e-01, -2.50901995e-01,
       -7.78358506e-01, -4.18493846e-01,  3.99942125e-02,  8.83836261e-01,
       -7.28709177e-01,  5.24647761e-01, -4.36700367e-01, -3.47166298e-01,
        4.72226156e-01, -2.19059337e-01, -1.17373303e-02,  8.08035747e-01,
        5.16673582e-01,  5.30194678e-01,  3.73107829e-02,  5.96006368e-01,
       -9.77082123e-01, -8.10224942e-01, -7.07793691e-01, -3.49790543e-01,
       -8.80451494e-01, -1.08764023e+00,  1.19159792e+00, -1.16779132e+00,
       -8.91488367e-01,  6.89097940e-01, -1.37027995e+00,  1.03231278e+00,
       -4.68816156e-01,  4.79101740e-01,  5.85719831e-01, -1.41222014e+00,
        1.42839751e-04,  1.04760806e+00,  1.02965258e+00,  1.09618916e+00,
        7.71710944e-01, -4.75498741e-01, -6.53065072e-01, -9.80625265e-01,
        1.44281734e+00,  8.32076210e-01, -1.24637686e+00, -2.80580559e-01,
       -1.23105035e+00, -6.04513872e-01,  1.36760121e+00,  4.61220951e-01,
        1.05112144e+00, -2.83456659e-02, -4.83272973e-01,  1.59012773e+00,
        9.18185441e-01,  1.08190382e-01,  7.01982700e-01, -3.09154586e-01,
        1.10273875e-01, -3.07469874e-01, -1.97655440e-01, -4.33879904e-01])
rnd.normal(size=len(X)).max()
2.1531824575115563
rnd.normal(size=len(X)).min()
-2.301921164735585
#使用散點圖觀察建立的數據集是什麼樣子
plt.scatter(X, y,marker='o',c='k',s=20)
plt.show()

[圖:帶噪音的正弦函數數據散點圖]

#爲後續建模做準備:sklearn只接受二維以上數組作爲特徵矩陣的輸入
X.shape
(100,)
X = X.reshape(-1, 1)
X.shape
(100, 1)
  1. 使用原始數據進行建模
#使用原始數據進行建模
LinearR = LinearRegression().fit(X, y)
TreeR = DecisionTreeRegressor(random_state=0).fit(X, y)
#放置畫布
fig, ax1 = plt.subplots(1)
#創建測試數據:一系列分佈在橫座標上的點
line = np.linspace(-3, 3, 1000, endpoint=False).reshape(-1, 1)

#將測試數據帶入predict接口,獲得模型的擬合效果並進行繪製
ax1.plot(line, LinearR.predict(line), linewidth=2, color='green',label="linear regression")
ax1.plot(line, TreeR.predict(line), linewidth=2, color='red',label="decision tree")

#將原數據上的擬合繪製在圖像上
ax1.plot(X[:, 0], y, 'o', c='k')
#其他圖形選項
ax1.legend(loc="best")
ax1.set_ylabel("Regression output") 
ax1.set_xlabel("Input feature")
ax1.set_title("Result before discretization") 
plt.tight_layout()
plt.show()
#從這個圖像來看,可以得出什麼結果?

[圖:離散化前,線性迴歸與決策樹在非線性數據上的擬合結果(Result before discretization)]

從圖像上可以看出,線性迴歸無法擬合出這條帶噪音的正弦曲線的真實面貌,只能夠模擬出大概的趨勢,而決策樹卻通過建立複雜的模型將幾乎每個點都擬合出來了。可見,使用線性迴歸模型來擬合非線性數據的效果並不好;決策樹雖然擬合得過於細緻,但相比之下,它在這份數據上的擬合效果還是更好一些。
決策樹無法寫作一個方程(我們在XGBoost章節中會詳細講解如何將決策樹定義成一個方程,但它絕對不是一個形似 $y = w_0 + w_1x_1 + \dots + w_nx_n$ 的方程),它是一個典型的非線性模型,當它被用於擬合非線性數據,可以發揮奇效。其他典型的非線性模型還包括使用高斯核的支持向量機、樹的集成算法,以及一切通過三角函數、指數函數等非線性方程來建立的模型。
根據這個思路,我們也許可以這樣推斷:線性模型用於擬合線性數據,非線性模型用於擬合非線性數據。但事實上機器學習遠遠比我們想象的靈活得多,線性模型可以用來擬合非線性數據,而非線性模型也可以用來擬合線性數據,更神奇的是,有的算法沒有模型也可以處理各類數據,而有的模型可以既可以是線性,也可以是非線性模型!接下來,我們就來一一討論這些問題。

5.2 使用分箱處理非線性問題

讓線性迴歸在非線性數據上表現提升的核心方法之一是對數據進行分箱,也就是離散化。與線性迴歸相比,我們常用的另一種迴歸是決策樹迴歸。我們之前擬合過一條帶有噪音的正弦曲線,以展示多元線性迴歸與決策樹的效果差異,下面先分析一下這張圖,然後再採取相應的措施來幫助我們的線性迴歸。

  1. 導入所需要的庫
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
  1. 創建需要擬合的數據集
rnd = np.random.RandomState(42) #設置隨機數種子
X = rnd.uniform(-3, 3, size=100) #random.uniform,在輸入的兩個數之間均勻取出size個隨機數
#生成y的思路:先使用NumPy中的函數生成一個sin函數圖像,然後再人爲添加噪音
y = np.sin(X) + rnd.normal(size=len(X)) / 3 #random.normal,生成size個服從正態分佈的隨機數
#使用散點圖觀察建立的數據集是什麼樣子
plt.scatter(X, y,marker='o',c='k',s=20)
plt.show()
#爲後續建模做準備:sklearn只接受二維以上數組作爲特徵矩陣的輸入
X.shape
X = X.reshape(-1, 1)

[圖:帶噪音的正弦函數數據散點圖]

  1. 使用原始數據進行建模
#使用原始數據進行建模
LinearR = LinearRegression().fit(X, y)
TreeR = DecisionTreeRegressor(random_state=0).fit(X, y)
#放置畫布
fig, ax1 = plt.subplots(1)

#創建測試數據:一系列分佈在橫座標上的點
line = np.linspace(-3, 3, 1000, endpoint=False).reshape(-1, 1)

#將測試數據帶入predict接口,獲得模型的擬合效果並進行繪製
ax1.plot(line, LinearR.predict(line), linewidth=2, color='green',label="linear regression")
ax1.plot(line, TreeR.predict(line), linewidth=2, color='red',label="decision tree")
#將原數據上的擬合繪製在圖像上
ax1.plot(X[:, 0], y, 'o', c='k')
#其他圖形選項
ax1.legend(loc="best")
ax1.set_ylabel("Regression output")
ax1.set_xlabel("Input feature")
ax1.set_title("Result before discretization") 
plt.tight_layout()
plt.show()
#從這個圖像來看,可以得出什麼結果?

[圖:離散化前,線性迴歸與決策樹的擬合結果(Result before discretization)]

從圖像上可以看出,線性迴歸無法擬合出這條帶噪音的正弦曲線的真實面貌,只能夠模擬出大概的趨勢,而決策樹卻通過建立複雜的模型將幾乎每個點都擬合出來了。此時此刻,決策樹正處於過擬合的狀態,對數據的學習過於細緻;而線性迴歸處於擬合不足的狀態,這是由模型本身只能擬合線性關係的性質決定的。爲了讓線性迴歸在類似的數據上變得更加強大,我們可以使用分箱,也就是離散化連續型變量的方法來處理原始數據,以此來提升線性迴歸的表現。來看看我們如何實現:

  1. 分箱及分箱的相關問題
from sklearn.preprocessing import KBinsDiscretizer
#將數據分箱
enc = KBinsDiscretizer(n_bins=10 #分成幾個箱子
                       ,encode="onehot") #編碼方式,另一個可選值是"ordinal"
X_binned = enc.fit_transform(X)
#encode模式"onehot":使用做啞變量方式做離散化
#之後返回一個稀疏矩陣(m,n_bins),每一列是一個分好的類別
#對每一個樣本而言,它包含的分類(箱子)中它表示爲1,其餘分類中它表示爲0
X.shape
(100, 1)
X_binned
<100x10 sparse matrix of type '<class 'numpy.float64'>'
	with 100 stored elements in Compressed Sparse Row format>
#使用pandas打開稀疏矩陣
import pandas as pd
pd.DataFrame(X_binned.toarray()).head()
0 1 2 3 4 5 6 7 8 9
0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0
1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0
2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
3 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0
4 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
#我們將使用分箱後的數據來訓練模型,在sklearn中,測試集和訓練集的結構必須保持一致,否則報錯
LinearR_ = LinearRegression().fit(X_binned, y)
LinearR_.predict(line) #line作爲測試集
---------------------------------------------------------------------------

ValueError                                Traceback (most recent call last)

<ipython-input-141-abc35ebde7c7> in <module>
----> 1 LinearR_.predict(line) #line作爲測試集


D:\ProgramData\Anaconda3\lib\site-packages\sklearn\linear_model\_base.py in predict(self, X)
    223             Returns predicted values.
    224         """
--> 225         return self._decision_function(X)
    226 
    227     _preprocess_data = staticmethod(_preprocess_data)


D:\ProgramData\Anaconda3\lib\site-packages\sklearn\linear_model\_base.py in _decision_function(self, X)
    207         X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
    208         return safe_sparse_dot(X, self.coef_.T,
--> 209                                dense_output=True) + self.intercept_
    210 
    211     def predict(self, X):


D:\ProgramData\Anaconda3\lib\site-packages\sklearn\utils\extmath.py in safe_sparse_dot(a, b, dense_output)
    149             ret = np.dot(a, b)
    150     else:
--> 151         ret = a @ b
    152 
    153     if (sparse.issparse(a) and sparse.issparse(b)


ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 10 is different from 1)
line.shape #測試
(1000, 1)
X_binned.shape #訓練
(100, 10)
#因此我們需要創建分箱後的測試集:按照已經建好的分箱模型將line分箱
line_binned = enc.transform(line)
line_binned
<1000x10 sparse matrix of type '<class 'numpy.float64'>'
	with 1000 stored elements in Compressed Sparse Row format>
line_binned.shape #分箱後的數據是無法進行繪圖的
(1000, 10)
LinearR_.predict(line_binned)
array([-0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.22510103, -0.22510103, -0.22510103, -0.22510103, -0.22510103,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.68407735, -0.68407735, -0.68407735,
       -0.68407735, -0.68407735, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.84238714, -0.84238714,
       -0.84238714, -0.84238714, -0.84238714, -0.90433112, -0.90433112,
       -0.90433112, -0.90433112, -0.90433112, -0.90433112, -0.90433112,
       -0.90433112, -0.90433112, -0.90433112, -0.90433112, -0.90433112,
       -0.90433112, -0.90433112, -0.90433112, -0.90433112, -0.90433112,
       -0.90433112, -0.90433112, -0.90433112, -0.90433112, -0.90433112,
       -0.90433112, -0.90433112, -0.90433112, -0.90433112, -0.90433112,
       -0.90433112, -0.90433112, -0.90433112, -0.90433112, -0.90433112,
       -0.90433112, -0.90433112, -0.90433112, -0.90433112, -0.90433112,
       -0.90433112, -0.90433112, -0.90433112, -0.90433112, -0.90433112,
       -0.90433112, -0.90433112, -0.90433112, -0.90433112, -0.90433112,
       -0.90433112, -0.90433112, -0.90433112, -0.90433112, -0.90433112,
       -0.90433112, -0.90433112, -0.90433112, -0.90433112, -0.90433112,
       -0.90433112, -0.90433112, -0.90433112, -0.90433112, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
       -0.72176296, -0.72176296, -0.72176296, -0.72176296, -0.72176296,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.01332773,
        0.01332773,  0.01332773,  0.01332773,  0.01332773,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.53043458,  0.53043458,  0.53043458,
        0.53043458,  0.53043458,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.98570463,  0.98570463,
        0.98570463,  0.98570463,  0.98570463,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.97481791,  0.97481791,
        0.97481791,  0.97481791,  0.97481791,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229,
        0.38539229,  0.38539229,  0.38539229,  0.38539229,  0.38539229])
LinearR_.predict(line_binned).shape
(1000,)
  1. 使用分箱數據進行建模和繪圖
#準備數據
enc = KBinsDiscretizer(n_bins=10,encode="onehot")
X_binned = enc.fit_transform(X)
line_binned = enc.transform(line)
#將兩張圖像繪製在一起,佈置畫布
fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True #讓兩張圖共享y軸上的刻度
                               , figsize=(10, 4))

#在圖1中佈置在原始數據上建模的結果
ax1.plot(line, LinearR.predict(line), linewidth=2, color='green',label="linear regression")
ax1.plot(line, TreeR.predict(line), linewidth=2, color='red',label="decision tree")
ax1.plot(X[:, 0], y, 'o', c='k')
ax1.legend(loc="best")
ax1.set_ylabel("Regression output")
ax1.set_xlabel("Input feature")
ax1.set_title("Result before discretization")

#使用分箱數據進行建模
LinearR_ = LinearRegression().fit(X_binned, y)
TreeR_ = DecisionTreeRegressor(random_state=0).fit(X_binned, y)

#進行預測,在圖2中佈置在分箱數據上進行預測的結果
ax2.plot(line #橫座標
         , LinearR_.predict(line_binned) #分箱後的特徵矩陣的結果
         , linewidth=2
         , color='green'
         , linestyle='-'
         , label='linear regression')
ax2.plot(line, TreeR_.predict(line_binned), linewidth=2, color='red',
         linestyle=':', label='decision tree')

#繪製和箱寬一致的豎線
ax2.vlines(enc.bin_edges_[0] # 設置豎線在x軸的位置
           , *plt.gca().get_ylim() # 設置豎線在y軸的上限和下限            
           , linewidth=1
           , alpha=.2)

#將原始數據分佈放置在圖像上
ax2.plot(X[:, 0], y, 'o', c='k')
#其他繪圖設定
ax2.legend(loc="best")
ax2.set_xlabel("Input feature")
ax2.set_title("Result after discretization") 
plt.tight_layout()
plt.show()

[圖:離散化前後的擬合結果對比(Result before discretization / Result after discretization)]

enc.bin_edges_
array([array([-2.9668673 , -2.55299973, -2.0639171 , -1.3945301 , -1.02797432,
       -0.21514527,  0.44239288,  1.14612193,  1.63693428,  2.32784522,
        2.92132162])], dtype=object)
enc.bin_edges_[0] # 數組中包含的數值就是分箱後的上限和下限,
# 把這些上限和下限作爲豎線的x軸座標
array([-2.9668673 , -2.55299973, -2.0639171 , -1.3945301 , -1.02797432,
       -0.21514527,  0.44239288,  1.14612193,  1.63693428,  2.32784522,
        2.92132162])
plt.gca().get_ylim() # 獲取y軸的上限和下限
(0.0, 1.0)


[*(plt.gca().get_ylim())]  # 加上*號表示,可以把元組中的數據取出來用
[0.0, 1.0]


從圖像上可以看出,離散化後線性迴歸和決策樹的預測結果完全相同了:線性迴歸比較成功地擬合了數據的分佈,而決策樹的過擬合效應也減輕了。由於特徵矩陣被分箱,特徵矩陣在每個區域內取到的值是恆定的,因此所有模型對同一個箱中的所有樣本都會給出相同的預測值。與分箱前的結果相比,線性迴歸明顯變得更加靈活,而決策樹的過擬合問題也得到了改善。但注意,一般來說我們不會使用分箱來改善決策樹的過擬合問題,因爲樹模型本身帶有豐富而有效的剪枝功能來防止過擬合。
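順帶可以做一個小驗證(以下只是一段示意性的草稿,假設上面定義的enc、X、y和在分箱數據上訓練的LinearR_仍在當前環境中):對one-hot分箱後的特徵做普通線性迴歸,同一個箱內所有樣本的預測值理論上等於該箱內y的均值,這正是"每個箱內預測值恆定"的來源:

import numpy as np
import pandas as pd

bin_id = enc.transform(X).toarray().argmax(axis=1) #每個樣本落入的箱子編號
bin_mean = pd.Series(y).groupby(bin_id).mean()     #每個箱內y的均值
pred = LinearR_.predict(enc.transform(X))          #模型在訓練數據上的預測值

np.allclose(pred, bin_mean.loc[bin_id].values)     #預期返回True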

在這個例子中,我們設置的分箱箱數爲10,不難想到這個箱數的設定肯定會影響模型最後的預測結果,我們來看看不同的箱數會如何影響迴歸的結果:

  1. 箱子數如何影響模型的結果
#準備數據
enc = KBinsDiscretizer(n_bins=5,encode="onehot")
X_binned = enc.fit_transform(X)
line_binned = enc.transform(line)
#將兩張圖像繪製在一起,佈置畫布
fig, ax2 = plt.subplots(ncols=1, figsize=(5, 4))


#使用分箱數據進行建模
LinearR_ = LinearRegression().fit(X_binned, y)

print(LinearR_.score(line_binned,np.sin(line)))

TreeR_ = DecisionTreeRegressor(random_state=0).fit(X_binned, y)

#進行預測,在圖2中佈置在分箱數據上進行預測的結果
ax2.plot(line #橫座標
         , LinearR_.predict(line_binned) #分箱後的特徵矩陣的結果
         , linewidth=2
         , color='green'
         , linestyle='-'
         , label='linear regression')
ax2.plot(line, TreeR_.predict(line_binned), linewidth=2, color='red',
         linestyle=':', label='decision tree')

#繪製和箱寬一致的豎線
ax2.vlines(enc.bin_edges_[0] # 設置豎線在x軸的位置
           , *plt.gca().get_ylim() # 設置豎線在y軸的上限和下限            
           , linewidth=1
           , alpha=.2)

#將原始數據分佈放置在圖像上
ax2.plot(X[:, 0], y, 'o', c='k')
#其他繪圖設定
ax2.legend(loc="best")
ax2.set_xlabel("Input feature")
ax2.set_title("Result after discretization") 
plt.tight_layout()
plt.show()
0.8649069759304867

[圖:分箱數爲5時的擬合結果(Result after discretization)]

  1. 如何選取最優的箱數
from sklearn.model_selection import cross_val_score as CVS 
import numpy as np

pred,score,var = [], [], []
binsrange = [2,5,10,15,20,30]
for i in binsrange:
    #實例化分箱類
    enc = KBinsDiscretizer(n_bins=i,encode="onehot")
    #轉換數據
    X_binned = enc.fit_transform(X)
    line_binned = enc.transform(line)
    #建立模型
    LinearR_ = LinearRegression()
    #全數據集上的交叉驗證
    cvresult = CVS(LinearR_,X_binned,y,cv=5)
    score.append(cvresult.mean())
    var.append(cvresult.var())
    #測試數據集上的打分結果
    pred.append(LinearR_.fit(X_binned,y).score(line_binned,np.sin(line)))
    
#繪製圖像
plt.figure(figsize=(6,5))
plt.plot(binsrange,pred,c="orange",label="test")
plt.plot(binsrange,score,c="k",label="full data")
plt.plot(binsrange,score+np.array(var)*0.5,c="red",linestyle="--",label = "var") 
plt.plot(binsrange,score-np.array(var)*0.5,c="red",linestyle="--")
plt.legend()
plt.show()

[圖:不同箱數下交叉驗證得分與測試得分的變化]

由上圖可知,選擇分箱數爲20箱是最佳的。因爲此時,方差最小且均值最高,模型最穩定。
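如果不想只靠肉眼讀圖,也可以直接從交叉驗證結果中取出均值最高的箱數(以下是一段示意性的草稿,假設上一段代碼中的binsrange、score、var仍在內存中):

import numpy as np

best_idx = int(np.argmax(score))   #交叉驗證均值最高的位置
print("最佳箱數:", binsrange[best_idx])
print("對應的交叉驗證均值:", score[best_idx], ",方差:", var[best_idx])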

5.3 多項式迴歸PolynomialFeatures

from sklearn.preprocessing import PolynomialFeatures 
import numpy as np

#如果原始數據是一維的
X = np.arange(1,4).reshape(-1,1)
X
array([[1],
       [2],
       [3]])
X.shape
(3, 1)
#二次多項式,參數degree控制多項式的次方
poly = PolynomialFeatures(degree=2)
#接口transform直接調用
X_ = poly.fit_transform(X)
X_
array([[1., 1., 1.],
       [1., 2., 4.],
       [1., 3., 9.]])
X_.shape
(3, 3)
#三次多項式
PolynomialFeatures(degree=3).fit_transform(X)
array([[ 1.,  1.,  1.,  1.],
       [ 1.,  2.,  4.,  8.],
       [ 1.,  3.,  9., 27.]])

不難注意到,多項式變化後數據看起來不太一樣了:首先,數據的特徵(維度)增加了,這正符合我們希望的將數據轉換到高維空間的願望。其次,維度的增加是有一定的規律的。不難發現,如果我們本來的特徵矩陣中只有一個特徵 $x$,那麼轉換後我們得到:$[1,\ x,\ x^2,\ \dots,\ x^d]$,也就是從 0 次方一直到設定的最高次 $d$ 次方的所有冪次。

#三次多項式,不帶與截距項相乘的x0
PolynomialFeatures(degree=3,include_bias=False).fit_transform(X)
array([[ 1.,  1.,  1.],
       [ 2.,  4.,  8.],
       [ 3.,  9., 27.]])
#爲什麼我們會希望不生成與截距相乘的x0呢?
#對於多項式迴歸來說,我們已經爲線性迴歸準備好了x0,但是線性迴歸並不知道
xxx = PolynomialFeatures(degree=3).fit_transform(X)
xxx.shape
(3, 4)
rnd = np.random.RandomState(42) #設置隨機數種子
y = rnd.randn(3)
y
array([ 0.49671415, -0.1382643 ,  0.64768854])
#生成了多少個係數?
LinearRegression().fit(xxx,y).coef_
array([ 3.08086889e-15, -3.51045297e-01, -6.06987134e-01,  2.19575463e-01])
#查看截距
LinearRegression().fit(xxx,y).intercept_
1.2351711202036884
#發現問題了嗎?線性迴歸並沒有把多項式生成的x0當作是截距項
#所以我們可以選擇:關閉多項式迴歸中的include_bias
#也可以選擇:關閉線性迴歸中的fit_intercept
#生成了多少個係數?
LinearRegression(fit_intercept=False).fit(xxx,y).coef_
array([ 1.00596411,  0.06916756, -0.83619415,  0.25777663])
#查看截距
LinearRegression(fit_intercept=False).fit(xxx,y).intercept_
0.0

不過,這只是一維狀況的表達,大多數時候我們的原始特徵矩陣不可能會是一維的,至少也是二維以上,很多時候還可能存在上千個特徵或者維度。現在我們來看看原始特徵矩陣是二維的狀況:

X = np.arange(6).reshape(3, 2)
X
array([[0, 1],
       [2, 3],
       [4, 5]])
#嘗試二次多項式
PolynomialFeatures(degree=2).fit_transform(X)
array([[ 1.,  0.,  1.,  0.,  0.,  1.],
       [ 1.,  2.,  3.,  4.,  6.,  9.],
       [ 1.,  4.,  5., 16., 20., 25.]])

很明顯,上面一維的轉換公式已經不適用了,但如果我們仔細看,是可以看出這樣的規律的:

$[1,\ x_1,\ x_2,\ x_1^2,\ x_1x_2,\ x_2^2]$,即常數項、兩個一次項,以及由兩個特徵組合出來的所有二次項。

當原始特徵爲二維的時候,多項式的二次變化突然將特徵增加到了六維,其中一維是常量(也就是截距)。當我們繼續使用線性迴歸去擬合的時候,我們會得到的方程如下:$\hat{y} = w_0 + w_1x_1 + w_2x_2 + w_3x_1^2 + w_4x_1x_2 + w_5x_2^2$

#嘗試三次多項式
PolynomialFeatures(degree=3).fit_transform(X)
array([[  1.,   0.,   1.,   0.,   0.,   1.,   0.,   0.,   0.,   1.],
       [  1.,   2.,   3.,   4.,   6.,   9.,   8.,  12.,  18.,  27.],
       [  1.,   4.,   5.,  16.,  20.,  25.,  64.,  80., 100., 125.]])

很明顯,我們可以看出這次生成的數據有這樣的規律:$[1,\ x_1,\ x_2,\ x_1^2,\ x_1x_2,\ x_2^2,\ x_1^3,\ x_1^2x_2,\ x_1x_2^2,\ x_2^3]$,也就是在二次項的基礎上又補齊了所有的三次項。

不難發現:當我們進行多項式轉換的時候,多項式會產出所有次數不超過設定最高次數的項。比如如果我們規定多項式的次數爲2,多項式就會產出所有次數爲1和次數爲2的項反饋給我們;相應的,如果我們規定多項式的次數爲n,則多項式會產出所有從次數爲1到次數爲n的項。注意,$x_1x_2$ 和 $x_1^2$ 一樣都是二次項,一個自變量的平方其實也就相當於是 $x_1 \times x_1$,所以在三次多項式中 $x_1^2x_2$ 就是三次項。

在多項式迴歸中,我們可以規定是否產生平方或者立方項。其實如果我們只要求高次項的話,$x_1x_2$ 會是一個比 $x_1^2$ 更好的高次項,因爲 $x_1x_2$ 和 $x_1$ 之間的共線性會比 $x_1^2$ 與 $x_1$ 之間的共線性好那麼一點點(只是一點點),而我們多項式轉化之後是需要使用線性迴歸模型來進行擬合的,就算機器學習中不是那麼在意數據上的基本假設,但是太過分的共線性還是會影響到模型的擬合。因此sklearn中存在着控制是否要生成平方項和立方項的參數interaction_only,默認爲False;將它設爲True時只保留交互項,可以在一定程度上減少共線性。來看這個參數是如何工作的:

PolynomialFeatures(degree=2).fit_transform(X)
array([[ 1.,  0.,  1.,  0.,  0.,  1.],
       [ 1.,  2.,  3.,  4.,  6.,  9.],
       [ 1.,  4.,  5., 16., 20., 25.]])
PolynomialFeatures(degree=2,interaction_only=True).fit_transform(X)
array([[ 1.,  0.,  1.,  0.],
       [ 1.,  2.,  3.,  6.],
       [ 1.,  4.,  5., 20.]])

對比之下,當interaction_only爲True的時候,只生成交互項(例如 $x_1x_2$),不再生成 $x_1^2$、$x_2^2$ 這樣的平方項。
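爲了更直觀地看到兩種設置下各自生成了哪些項,可以藉助後面會介紹的接口get_feature_names對比一下(以下只是一段示意性的草稿,假設X仍是上面3行2列的特徵矩陣):

poly_full = PolynomialFeatures(degree=2).fit(X)
poly_inter = PolynomialFeatures(degree=2, interaction_only=True).fit(X)

poly_full.get_feature_names()  #['1', 'x0', 'x1', 'x0^2', 'x0 x1', 'x1^2']
poly_inter.get_feature_names() #['1', 'x0', 'x1', 'x0 x1'],平方項被去掉了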

從之前的許多次嘗試中我們可以看出,隨着多項式的次數逐漸變高,特徵矩陣會被轉化得越來越複雜。不僅是次數,當特徵矩陣中的維度數(特徵數)增加的時候,多項式同樣會變得更加複雜:

#更高維度的原始特徵矩陣
X = np.arange(9).reshape(3, 3)
X
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
PolynomialFeatures(degree=2).fit_transform(X)
array([[ 1.,  0.,  1.,  2.,  0.,  0.,  0.,  1.,  2.,  4.],
       [ 1.,  3.,  4.,  5.,  9., 12., 15., 16., 20., 25.],
       [ 1.,  6.,  7.,  8., 36., 42., 48., 49., 56., 64.]])
PolynomialFeatures(degree=3).fit_transform(X)
array([[  1.,   0.,   1.,   2.,   0.,   0.,   0.,   1.,   2.,   4.,   0.,
          0.,   0.,   0.,   0.,   0.,   1.,   2.,   4.,   8.],
       [  1.,   3.,   4.,   5.,   9.,  12.,  15.,  16.,  20.,  25.,  27.,
         36.,  45.,  48.,  60.,  75.,  64.,  80., 100., 125.],
       [  1.,   6.,   7.,   8.,  36.,  42.,  48.,  49.,  56.,  64., 216.,
        252., 288., 294., 336., 384., 343., 392., 448., 512.]])
X_ = PolynomialFeatures(degree=20).fit_transform(X) 
X_.shape
(3, 1771)

如此,多項式變化對於數據會有怎樣的影響就一目瞭然了:隨着原特徵矩陣的維度上升、我們規定的最高次數上升,數據會變得越來越複雜,維度越來越多,而且新增維度的構成方式也比一維情況複雜得多。因此,多項式迴歸沒有固定的模型表達式,模型最終長什麼樣子是由數據和最高次數共同決定的,我們無法斷言某個數學表達式"就是多項式迴歸的數學表達",要求解多項式迴歸也因此不是一件容易的事,感興趣的讀者可以自己嘗試用最小二乘法求解多項式迴歸。接下來,我們就來看看多項式迴歸的根本作用:處理非線性問題。
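爲了更直觀地感受維度隨最高次數增長的速度,可以用一段小草稿把不同degree下升維後的特徵數打印出來(假設X仍是上面3行3列的特徵矩陣):

for d in [1, 2, 3, 5, 10, 20]:
    n_features = PolynomialFeatures(degree=d).fit_transform(X).shape[1]
    print("degree =", d, "-> 特徵數 =", n_features) #degree=20時應爲1771,與上面的結果一致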

5.3.2 多項式迴歸處理非線性問題

from sklearn.preprocessing import PolynomialFeatures as PF 
from sklearn.linear_model import LinearRegression
import numpy as np
rnd = np.random.RandomState(42) #設置隨機數種子
X = rnd.uniform(-3, 3, size=100)
y = np.sin(X) + rnd.normal(size=len(X)) / 3
#將X升維,準備好放入sklearn中
X = X.reshape(-1,1)
#創建測試數據,均勻分佈在訓練集X的取值範圍內的一千個點
line = np.linspace(-3, 3, 1000, endpoint=False).reshape(-1, 1)
#原始特徵矩陣的擬合結果
LinearR = LinearRegression().fit(X, y)
#對訓練數據的擬合
LinearR.score(X,y)
0.5361526059318595
#對測試數據的擬合
LinearR.score(line,np.sin(line))
0.6800102369793312
#多項式擬合,設定高次項
d=5 

#進行高次項轉換
poly = PF(degree=d)
X_ = poly.fit_transform(X) # 將訓練數據升維
line_ = PF(degree=d).fit_transform(line) # 將測試數據升維
#訓練數據的擬合
LinearR_ = LinearRegression().fit(X_, y) 
LinearR_.score(X_,y)
0.8561679370344799
#測試數據的擬合
LinearR_.score(line_,np.sin(line))
0.9868904451787978
import matplotlib.pyplot as plt
d=5
#和上面展示一致的建模流程
LinearR = LinearRegression().fit(X, y)
X_ = PF(degree=d).fit_transform(X)
LinearR_ = LinearRegression().fit(X_, y)
line = np.linspace(-3, 3, 1000, endpoint=False).reshape(-1, 1) 
line_ = PF(degree=d).fit_transform(line)
#放置畫布
fig, ax1 = plt.subplots(1)
#將測試數據帶入predict接口,獲得模型的擬合效果並進行繪製
ax1.plot(line, LinearR.predict(line), linewidth=2, color='green',label="linear regression")
ax1.plot(line, LinearR_.predict(line_), linewidth=2, color='red',label="Polynomial regression")
#將原數據上的擬合繪製在圖像上
ax1.plot(X[:, 0], y, 'o', c='k')
#其他圖形選項
ax1.legend(loc="best")
ax1.set_ylabel("Regression output")
ax1.set_xlabel("Input feature")
ax1.set_title("Linear Regression ordinary vs poly") 
plt.tight_layout()
plt.show()
#來一起鼓掌,感嘆多項式迴歸的神奇

[圖:普通線性迴歸與多項式迴歸的擬合效果對比(Linear Regression ordinary vs poly)]

從這裏大家可以看出,多項式迴歸能夠較好地擬合非線性數據,還不容易發生過擬合,可以說是保留了線性迴歸作爲線性模型所帶的"不容易過擬合"和"計算快速"的性質,同時又實現了對非線性數據的優秀擬合。到了這裏,相信大家對於多項式迴歸的效果已經不再懷疑了。多項式迴歸非常迷人也非常神奇,因此一直以來都有各種各樣圍繞着多項式迴歸進行的討論。在這裏,爲大家梳理幾個常見問題和討論,供大家參考。

5.3.3 多項式迴歸的可解釋性

線性迴歸是一個具有高解釋性的模型,它能夠對每個特徵擬合出參數 $w$,以幫助我們理解每個特徵對於標籤的作用。當我們進行了多項式轉換後,儘管我們還是形成形如線性迴歸的方程,但隨着數據維度和多項式次數的上升,方程也變得異常複雜,我們可能無法一眼看出增維後的特徵是由之前的什麼特徵組成的(之前的例子中我們都是靠肉眼判斷的)。不過,多項式迴歸的可解釋性依然是存在的,我們可以使用接口get_feature_names來調用生成的新特徵矩陣中各個特徵的名稱,以便幫助我們解釋模型。來看下面的例子:

import numpy as np
from sklearn.preprocessing import PolynomialFeatures 
from sklearn.linear_model import LinearRegression

X = np.arange(9).reshape(3, 3)
X
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
poly = PolynomialFeatures(degree=5).fit(X)
poly
PolynomialFeatures(degree=5, include_bias=True, interaction_only=False,
                   order='C')
#重要接口get_feature_names 
poly.get_feature_names()
['1',
 'x0',
 'x1',
 'x2',
 'x0^2',
 'x0 x1',
 'x0 x2',
 'x1^2',
 'x1 x2',
 'x2^2',
 'x0^3',
 'x0^2 x1',
 'x0^2 x2',
 'x0 x1^2',
 'x0 x1 x2',
 'x0 x2^2',
 'x1^3',
 'x1^2 x2',
 'x1 x2^2',
 'x2^3',
 'x0^4',
 'x0^3 x1',
 'x0^3 x2',
 'x0^2 x1^2',
 'x0^2 x1 x2',
 'x0^2 x2^2',
 'x0 x1^3',
 'x0 x1^2 x2',
 'x0 x1 x2^2',
 'x0 x2^3',
 'x1^4',
 'x1^3 x2',
 'x1^2 x2^2',
 'x1 x2^3',
 'x2^4',
 'x0^5',
 'x0^4 x1',
 'x0^4 x2',
 'x0^3 x1^2',
 'x0^3 x1 x2',
 'x0^3 x2^2',
 'x0^2 x1^3',
 'x0^2 x1^2 x2',
 'x0^2 x1 x2^2',
 'x0^2 x2^3',
 'x0 x1^4',
 'x0 x1^3 x2',
 'x0 x1^2 x2^2',
 'x0 x1 x2^3',
 'x0 x2^4',
 'x1^5',
 'x1^4 x2',
 'x1^3 x2^2',
 'x1^2 x2^3',
 'x1 x2^4',
 'x2^5']

使用加利福尼亞房價數據集給大家作爲例子,當我們有特徵名稱的時候,可以直接在接口get_feature_names()中傳入特徵名稱,來查看新特徵究竟是由原特徵矩陣中的哪些特徵組成的:

from sklearn.datasets import fetch_california_housing as fch 
import pandas as pd
housevalue = fch()
X = pd.DataFrame(housevalue.data)
y = housevalue.target
housevalue.feature_names
['MedInc',
 'HouseAge',
 'AveRooms',
 'AveBedrms',
 'Population',
 'AveOccup',
 'Latitude',
 'Longitude']
X.columns = ["住戶收入中位數","房屋使用年代中位數","平均房間數目","平均臥室數目","街區人口","平均入住率","街區的緯度","街區的經度"]
poly = PolynomialFeatures(degree=2).fit(X,y)
poly.get_feature_names(X.columns)
['1',
 '住戶收入中位數',
 '房屋使用年代中位數',
 '平均房間數目',
 '平均臥室數目',
 '街區人口',
 '平均入住率',
 '街區的緯度',
 '街區的經度',
 '住戶收入中位數^2',
 '住戶收入中位數 房屋使用年代中位數',
 '住戶收入中位數 平均房間數目',
 '住戶收入中位數 平均臥室數目',
 '住戶收入中位數 街區人口',
 '住戶收入中位數 平均入住率',
 '住戶收入中位數 街區的緯度',
 '住戶收入中位數 街區的經度',
 '房屋使用年代中位數^2',
 '房屋使用年代中位數 平均房間數目',
 '房屋使用年代中位數 平均臥室數目',
 '房屋使用年代中位數 街區人口',
 '房屋使用年代中位數 平均入住率',
 '房屋使用年代中位數 街區的緯度',
 '房屋使用年代中位數 街區的經度',
 '平均房間數目^2',
 '平均房間數目 平均臥室數目',
 '平均房間數目 街區人口',
 '平均房間數目 平均入住率',
 '平均房間數目 街區的緯度',
 '平均房間數目 街區的經度',
 '平均臥室數目^2',
 '平均臥室數目 街區人口',
 '平均臥室數目 平均入住率',
 '平均臥室數目 街區的緯度',
 '平均臥室數目 街區的經度',
 '街區人口^2',
 '街區人口 平均入住率',
 '街區人口 街區的緯度',
 '街區人口 街區的經度',
 '平均入住率^2',
 '平均入住率 街區的緯度',
 '平均入住率 街區的經度',
 '街區的緯度^2',
 '街區的緯度 街區的經度',
 '街區的經度^2']
X_ = poly.transform(X)
#在這之後,我們依然可以直接建立模型,然後使用線性迴歸的coef_屬性來查看什麼特徵對標籤的影響最大
reg = LinearRegression().fit(X_,y)
coef = reg.coef_
coef
array([ 5.91954055e-08, -1.12430252e+01, -8.48898543e-01,  6.44105898e+00,
       -3.15913288e+01,  4.06090344e-04,  1.00386234e+00,  8.70568188e+00,
        5.88063272e+00, -3.13081272e-02,  1.85994682e-03,  4.33020468e-02,
       -1.86142278e-01,  5.72831545e-05, -2.59019509e-03, -1.52505713e-01,
       -1.44242939e-01,  2.11725336e-04, -1.26219010e-03,  1.06115056e-02,
        2.81885293e-06, -1.81716947e-03, -1.00690372e-02, -9.99950167e-03,
        7.26947730e-03, -6.89064340e-02, -6.82365908e-05,  2.68878842e-02,
        8.75089875e-02,  8.22890339e-02,  1.60180950e-01,  5.14264271e-04,
       -8.71911472e-02, -4.37042992e-01, -4.04150578e-01,  2.73779577e-09,
        1.91426762e-05,  2.29529789e-05,  1.46567733e-05,  8.71560978e-05,
        2.13344592e-02,  1.62412938e-02,  6.18867358e-02,  1.08107173e-01,
        3.99077351e-02])
[*zip(poly.get_feature_names(X.columns),reg.coef_)]
[('1', 5.919540552171548e-08),
 ('住戶收入中位數', -11.243025193367279),
 ('房屋使用年代中位數', -0.848898543001153),
 ('平均房間數目', 6.441058980103585),
 ('平均臥室數目', -31.5913287885817),
 ('街區人口', 0.00040609034379415385),
 ('平均入住率', 1.003862338673655),
 ('街區的緯度', 8.705681884585069),
 ('街區的經度', 5.880632723650286),
 ('住戶收入中位數^2', -0.03130812716756933),
 ('住戶收入中位數 房屋使用年代中位數', 0.0018599468175778376),
 ('住戶收入中位數 平均房間數目', 0.04330204675617265),
 ('住戶收入中位數 平均臥室數目', -0.18614227806341782),
 ('住戶收入中位數 街區人口', 5.7283154455812717e-05),
 ('住戶收入中位數 平均入住率', -0.0025901950881940016),
 ('住戶收入中位數 街區的緯度', -0.15250571255697834),
 ('住戶收入中位數 街區的經度', -0.1442429393754428),
 ('房屋使用年代中位數^2', 0.00021172533625687026),
 ('房屋使用年代中位數 平均房間數目', -0.0012621900983789294),
 ('房屋使用年代中位數 平均臥室數目', 0.010611505608370727),
 ('房屋使用年代中位數 街區人口', 2.818852930851565e-06),
 ('房屋使用年代中位數 平均入住率', -0.0018171694688044425),
 ('房屋使用年代中位數 街區的緯度', -0.010069037156389547),
 ('房屋使用年代中位數 街區的經度', -0.009999501671412017),
 ('平均房間數目^2', 0.007269477298002129),
 ('平均房間數目 平均臥室數目', -0.06890643404856586),
 ('平均房間數目 街區人口', -6.823659076329969e-05),
 ('平均房間數目 平均入住率', 0.026887884152557523),
 ('平均房間數目 街區的緯度', 0.0875089875407275),
 ('平均房間數目 街區的經度', 0.08228903389524618),
 ('平均臥室數目^2', 0.1601809500092068),
 ('平均臥室數目 街區人口', 0.0005142642707304053),
 ('平均臥室數目 平均入住率', -0.08719114715677954),
 ('平均臥室數目 街區的緯度', -0.43704299179225914),
 ('平均臥室數目 街區的經度', -0.4041505775830314),
 ('街區人口^2', 2.737795767870921e-09),
 ('街區人口 平均入住率', 1.914267616391803e-05),
 ('街區人口 街區的緯度', 2.2952978919604794e-05),
 ('街區人口 街區的經度', 1.4656773311472193e-05),
 ('平均入住率^2', 8.715609781424712e-05),
 ('平均入住率 街區的緯度', 0.021334459219533943),
 ('平均入住率 街區的經度', 0.01624129382914855),
 ('街區的緯度^2', 0.06188673577348754),
 ('街區的緯度 街區的經度', 0.10810717324450632),
 ('街區的經度^2', 0.039907735079891565)]
#放到dataframe中進行排序
coeff = pd.DataFrame([poly.get_feature_names(X.columns),reg.coef_.tolist()]).T 
coeff.columns = ["feature","coef"]
coeff.sort_values(by="coef") # df.sort_values(by="coef") 按照coef字段的值進行排序,默認升序
feature coef
4 平均臥室數目 -31.5913
1 住戶收入中位數 -11.243
2 房屋使用年代中位數 -0.848899
33 平均臥室數目 街區的緯度 -0.437043
34 平均臥室數目 街區的經度 -0.404151
12 住戶收入中位數 平均臥室數目 -0.186142
15 住戶收入中位數 街區的緯度 -0.152506
16 住戶收入中位數 街區的經度 -0.144243
32 平均臥室數目 平均入住率 -0.0871911
25 平均房間數目 平均臥室數目 -0.0689064
9 住戶收入中位數^2 -0.0313081
22 房屋使用年代中位數 街區的緯度 -0.010069
23 房屋使用年代中位數 街區的經度 -0.0099995
14 住戶收入中位數 平均入住率 -0.0025902
21 房屋使用年代中位數 平均入住率 -0.00181717
18 房屋使用年代中位數 平均房間數目 -0.00126219
26 平均房間數目 街區人口 -6.82366e-05
35 街區人口^2 2.7378e-09
0 1 5.91954e-08
20 房屋使用年代中位數 街區人口 2.81885e-06
38 街區人口 街區的經度 1.46568e-05
36 街區人口 平均入住率 1.91427e-05
37 街區人口 街區的緯度 2.2953e-05
13 住戶收入中位數 街區人口 5.72832e-05
39 平均入住率^2 8.71561e-05
17 房屋使用年代中位數^2 0.000211725
5 街區人口 0.00040609
31 平均臥室數目 街區人口 0.000514264
10 住戶收入中位數 房屋使用年代中位數 0.00185995
24 平均房間數目^2 0.00726948
19 房屋使用年代中位數 平均臥室數目 0.0106115
41 平均入住率 街區的經度 0.0162413
40 平均入住率 街區的緯度 0.0213345
27 平均房間數目 平均入住率 0.0268879
44 街區的經度^2 0.0399077
11 住戶收入中位數 平均房間數目 0.043302
42 街區的緯度^2 0.0618867
29 平均房間數目 街區的經度 0.082289
28 平均房間數目 街區的緯度 0.087509
43 街區的緯度 街區的經度 0.108107
30 平均臥室數目^2 0.160181
6 平均入住率 1.00386
8 街區的經度 5.88063
3 平均房間數目 6.44106
7 街區的緯度 8.70568
reg.coef_
array([ 4.36693293e-01,  9.43577803e-03, -1.07322041e-01,  6.45065694e-01,
       -3.97638942e-06, -3.78654265e-03, -4.21314378e-01, -4.34513755e-01])
reg.coef_.tolist()
[0.4366932931343249,
 0.00943577803323849,
 -0.10732204139090416,
 0.6450656935198118,
 -3.97638942118729e-06,
 -0.0037865426549709763,
 -0.42131437752714423,
 -0.4345137546747774]

可以發現,多項式變換之後,數據的可解釋性依然存在;我們還可以通過這樣的手段做特徵工程中的特徵創造。多項式幫助我們進行了一系列特徵之間相乘的組合,若能夠找出組合起來後對標籤貢獻巨大的特徵,那我們就是創造了新的有效特徵,對於任何學科而言,發現新特徵都是非常有價值的。
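沿着這個思路,可以按係數的絕對值對上面的結果再排一次序,把對標籤影響最大的組合特徵挑出來,作爲"特徵創造"的候選(以下是一段示意性的草稿,假設上面構造的coeff表仍在內存中):

coeff["abs_coef"] = coeff["coef"].astype(float).abs() #取係數的絕對值
coeff.sort_values(by="abs_coef", ascending=False).head(10) #影響最大的前10個特徵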

在加利福尼亞房屋價值數據集上再次確認多項式迴歸提升模型表現的能力:

#先在原始特徵矩陣上建模作爲對照,再查看多項式變化之後模型的擬合效果提升了多少
poly = PolynomialFeatures(degree=4).fit(X,y)
X_ = poly.transform(X)
reg = LinearRegression().fit(X,y)
reg.score(X,y)
0.6062326851998049
from time import time
time0 = time()
reg_ = LinearRegression().fit(X_,y)
print("R2:{}".format(reg_.score(X_,y)))
print("time:{}".format(time()-time0))
R2:0.7452006076442239
time:0.8058722019195557
#假設使用其他模型?
from sklearn.ensemble import RandomForestRegressor as RFR
time0 = time()
print("R2:{}".format(RFR(n_estimators=100).fit(X,y).score(X,y))) 
print("time:{}".format(time()-time0))
R2:0.9743876150808999
time:10.486432075500488

總結

本篇文章中,主要講解了多元線性迴歸、嶺迴歸、Lasso、多項式迴歸總計四個算法,它們都是圍繞着原始的線性迴歸進行的拓展和改進。其中嶺迴歸和Lasso是爲了解決多元線性迴歸中使用最小二乘法的各種限制,主要用途是消除多重共線性帶來的影響並且做特徵選擇;而多項式迴歸解決了線性迴歸無法擬合非線性數據的明顯缺點,核心作用是提升模型的表現。
