Company_profit_analysis.py

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import linear_model
from mpl_toolkits.mplot3d import axes3d
import seaborn as sns
#讀取數據
readdata = pd.read_csv('Advertising.csv')
data = np.array(readdata.values)#數值變成矩陣
#訓練數據
x_train = data[0:150,1:3]#取數據的0道150行,1,2列
y_train = data[0:150,4]

#測試數據
x_test = data[150:200,1:3]
y_test = data[150:200,4]

#迴歸分析
regr = linear_model.LinearRegression()
regr.fit(x_train,y_train)

#打印出相關係數和截距等信息
print('Cofficients: \n',regr.coef)#打印係數
print('Intercept:',regr.intercept
)#打印截距
#mean square error
print('Residual sum of squares: %.2f'%np.mean((regr.predict(x_test)-y_test)))
#explained variance score:1 is perfect prediction
print('variance score: %.2f' %regr.score(x_test,y_test))

#得出迴歸函數 並自定義數據
x_line = np.linspace(0,300)
y_line = np.linspace(0,50)
z_line = 0.04699836x_line+0.17913965y_line+3.0043106117608556

#畫圖
fig = plt.figure()
ax = plt.subplot(111,projection = '3d')#創建一個3維的繪圖工具
ax.scatter(data[:,1],data[:,2],data[:,4],c = 'red')
ax.plot(x_line,y_line,z_line,c = 'blue')
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章