advcase.py-20180704

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul  3 17:43:53 2018

@author: vicky
"""

# Third-party packages
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.formula.api as smf
from sklearn.metrics import mean_squared_error

try:
    # sklearn.model_selection exists since scikit-learn 0.18; the old
    # sklearn.cross_validation module was removed in 0.20.
    from sklearn.model_selection import train_test_split
except ImportError:  # scikit-learn < 0.18
    from sklearn.cross_validation import train_test_split

# Load the Advertising dataset and print descriptive statistics for every
# variable.
# The path must be a raw string: in an ordinary string literal the "\U" in
# "C:\Users" starts a unicode escape, which is a SyntaxError in Python 3.
# NOTE(review): machine-specific path -- point it at your own copy of the file.
# NOTE(review): if the CSV has a leading unnamed row-index column, pass
# index_col=0 so it does not appear as a variable in describe()/corr().
DATA_PATH = r'C:\Users\wenyun.wxw\Desktop\Advertising.csv'
data = pd.read_csv(DATA_PATH)
print(data.describe())

# Scatter plot of each advertising channel (TV, Radio, Newspaper) against
# Sales, with a regression line drawn in every panel (kind="reg").
# `height` is seaborn's current name for the former `size` argument, and the
# old `sns.plt` alias is gone, so the figure is shown through pyplot directly.
sns.pairplot(data, x_vars=["TV", "Radio", "Newspaper"], y_vars="Sales",
             height=5, aspect=0.7, kind="reg")
plt.show()

# Correlation matrix of the variables, drawn as a heat map.
corr = data.corr()

# Draw on a fresh figure and show it at the end; otherwise later plt.* calls
# would keep drawing onto these heat-map axes.
plt.figure()
plt.imshow(corr, cmap=plt.cm.Blues, interpolation='nearest')
plt.colorbar()
# Label ticks with the correlation matrix's own columns so labels always
# line up with the rows/columns that were actually plotted.
tick_marks = list(range(len(corr.columns)))
plt.xticks(tick_marks, corr.columns, rotation='vertical')
plt.yticks(tick_marks, corr.columns)
plt.tight_layout()  # keep the vertical x labels from being clipped
plt.show()

# Split the rows 80/20 into a training set and a held-out test set.
# The fixed seed makes the split reproducible between runs.
Train, Test = train_test_split(
    data,
    random_state=1234,
    train_size=0.8,
)

# Fit three ordinary-least-squares models of Sales on the training set:
#   fit  -- all three advertising channels
#   fit2 -- Newspaper removed
#   fit3 -- fit2 plus a TV x Radio interaction term
# A formula only reads the columns it names, so Newspaper does not need to be
# dropped from the frame first. The summaries are printed explicitly because
# a bare `fit.summary()` expression produces no output when run as a script.
fit = smf.ols('Sales ~ TV + Radio + Newspaper', data=Train).fit()
print(fit.summary())

# Drop Newspaper.
fit2 = smf.ols('Sales ~ TV + Radio', data=Train).fit()
print(fit2.summary())

# Add the TV:Radio interaction.
fit3 = smf.ols('Sales ~ TV + Radio + TV:Radio', data=Train).fit()
print(fit3.summary())

# Predict Sales on the held-out test set with each model. Each formula model
# builds its design matrix from only the columns it uses, so the full Test
# frame can be passed to all three.
pred = fit.predict(exog=Test)
pred2 = fit2.predict(exog=Test)
pred3 = fit3.predict(exog=Test)

# Root-mean-square error of each model on the test set (lower is better).
RMSE = np.sqrt(mean_squared_error(Test.Sales, pred))
RMSE2 = np.sqrt(mean_squared_error(Test.Sales, pred2))
RMSE3 = np.sqrt(mean_squared_error(Test.Sales, pred3))

# Label every line with its model so the three errors can be told apart.
print('RMSE (TV + Radio + Newspaper) = %.4f' % RMSE)
print('RMSE (TV + Radio)             = %.4f' % RMSE2)
print('RMSE (TV + Radio + TV:Radio)  = %.4f' % RMSE3)

# Compare actual and predicted test-set Sales for the full model (fit).
# A point on the dashed y = x line is a perfect prediction; its vertical
# distance from the line is that prediction's error.
plt.style.use('ggplot')
plt.figure()  # start a fresh figure rather than drawing on one left open
plt.scatter(Test.Sales, pred, c='b', label='Observations')

# Span the reference line over the combined range of both axes so it
# covers every plotted point.
lims = [min(Test.Sales.min(), pred.min()), max(Test.Sales.max(), pred.max())]
plt.plot(lims, lims, 'r--', lw=2, label='Perfect prediction (y = x)')

plt.title('Real Values VS. Predict Values')
plt.xlabel('Real Values')
plt.ylabel('Prediction Values')
plt.legend(loc='upper left')
plt.show()

# Residual plot for the full model (fit): test-set residuals
# (actual - predicted) against the predicted values. The dashed line marks
# zero error; if the linear model is adequate, the residuals should scatter
# around it with no visible pattern.
residuals = Test.Sales - pred

plt.style.use('ggplot')
plt.figure()  # start a fresh figure rather than drawing on one left open
plt.scatter(pred, residuals, c='b', label='Residuals')
plt.axhline(0, color='r', linestyle='--', lw=2)
plt.title('Residual Plot')
plt.xlabel('Fitted Values')
plt.ylabel('Residuals')
plt.legend(loc='upper left')
plt.show()

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章