DL多項式升維(代碼實戰)

學了本文你能學到什麼?僅供學習,如有疑問,請留言。。。

 

# -*- coding: utf-8 -*-
# Author       :   szy
# Create Date  :   2019/10/30
# 多項式升維,實戰,訓練模型喝估值
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

np.random.seed(42)
m = 100
# -3到3的區間
X = 6 * np.random.rand(m, 1) - 3
y = 0.5 * X**2 + X + 2 + np.random.randn(m, 1)
plt.plot(X, y, 'b.')
# plt.show()
X_train = X[:80]
y_train = X[:80]
X_test = X[80:]
y_test = X[80:]

d = {1: "g-", 2: "r+", 10: "y*"}
for i in d:
    # 開始升維了
    poly_features = PolynomialFeatures(degree=i, include_bias=True)
    X_poly_train = poly_features.fit_transform(X_train)
    X_poly_test = poly_features.fit_transform(X_test)
    print(X_train[0])
    print(X_train.shape)
    print(X_poly_train[0])
    print(X_poly_train.shape)
    print("-------------------------------------------------------------------")
    lin_reg = LinearRegression(fit_intercept=False)
    lin_reg.fit(X_poly_train, y_train)
    print(lin_reg.intercept_, lin_reg.coef_)
    # 看看是否隨着degree的增加升維,是否過擬合了
    y_train_predict = lin_reg.predict(X_poly_train)
    y_test_predict = lin_reg.predict(X_poly_test)

    plt.plot(X_poly_train[:, 1], y_train_predict, d[i])

    print(mean_squared_error(y_train, y_train_predict))
    print(mean_squared_error(y_test, y_test_predict))
plt.show()

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章