#!/usr/bin/env python
# coding: utf-8
# In[8]:
# Database operations: load the data table from MySQL.
import os

import pandas as pd
import pymysql
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
# Connection settings come from the environment so the password is no longer
# committed to source control. Host/port/user/database default to the values
# this script previously hard-coded; MYSQL_PASSWORD must be supplied.
db = pymysql.connect(
    host=os.environ.get("MYSQL_HOST", "127.0.0.1"),
    port=int(os.environ.get("MYSQL_PORT", "3306")),
    user=os.environ.get("MYSQL_USER", "root"),
    passwd=os.environ.get("MYSQL_PASSWORD", ""),
    db=os.environ.get("MYSQL_DB", "data"),
    charset="utf8",
)
try:
    # The cursor context manager closes the cursor when the block exits.
    with db.cursor() as cursor:
        cursor.execute("select version()")
        data = cursor.fetchone()
        print(" Database Version:%s" % data)
        sql = "select * from data"
        cursor.execute(sql)
        data = cursor.fetchall()
finally:
    # All rows are now in memory; nothing later in this script queries the
    # database, so release the connection even if a query failed.
    db.close()
# Passing the column names to the constructor (instead of assigning
# df.columns afterwards) also works when the query returns zero rows.
df = pd.DataFrame(list(data), columns=["Trust", 'Quantity', 'Satisfied', 'Loyal'])
# A bare `df.head()` is discarded when run as a script; print it instead.
print(df.head())
# In[10]:
# This cell can be skipped: it only previews ./datasets/data.xls, while the
# analysis below works on the DataFrame loaded from MySQL.
# The context manager closes the workbook handle once the sheet has been read.
with pd.ExcelFile('./datasets/data.xls') as xls:
    print(pd.read_excel(xls))
# In[11]:
# Working data: the first four columns form the analysis frame, and the
# fourth column (Loyal, given the column order set when df was built) is
# printed as the target values.
data = df
new_data, y = data.iloc[:, :4], data.iloc[:, 3]
print('head:', new_data.head(), '\nShape:', new_data.shape)
print('value:', y)
# In[12]:
# Summary statistics (count, mean, std, quartiles, ...) for each column.
summary_stats = new_data.describe()
print(summary_stats)
# In[13]:
# Pairwise Pearson correlation matrix.
# Author's note (depends on the data, not verified here): all three
# predictors are strongly correlated with Loyal, treating r > 0.6 as
# correlated.
correlation_matrix = new_data.corr(method="pearson")
print(correlation_matrix)
# In[67]:
# Scatter each predictor against Loyal with a fitted regression line
# (kind='reg'). Author's note: the plots show a clear linear relationship.
predictor_columns = ['Trust', 'Quantity', 'Satisfied']
sns.pairplot(
    new_data,
    x_vars=predictor_columns,
    y_vars='Loyal',
    height=7,
    aspect=0.8,
    kind='reg',
)
plt.savefig('cor.jpg')
plt.show()
# In[14]:
# Split into training (80%) and test (remaining 20%) sets.
# A fixed random_state makes the split, and therefore the coefficients and
# the R^2 score printed later, reproducible from run to run.
features = new_data.iloc[:, :3]  # Trust, Quantity, Satisfied
target = new_data.Loyal
X_train, X_test, Y_train, Y_test = train_test_split(
    features, target, train_size=.80, random_state=42
)
print("原始数据特征:", features.shape,
      ",训练数据特征:", X_train.shape,
      ",测试数据特征:", X_test.shape)
print("原始数据标签:", target.shape,
      ",训练数据标签:", Y_train.shape,
      ",测试数据标签:", Y_test.shape)
# In[15]:
# Fit an ordinary least-squares model on the training split and print the
# fitted line. LinearRegression.fit returns the estimator itself.
model = LinearRegression().fit(X_train, Y_train)
a, b = model.intercept_, model.coef_  # intercept, regression coefficients
print("最佳拟合线:截距", a, ",回归系数:", b)
# In[16]:
# R^2 (coefficient of determination) on the held-out test set; higher is
# better, and 1.0 is a perfect fit.
score = model.score(X_test, Y_test)
print(score)
# Predictions for the test set. They are plotted against the true values in
# the next cell. Plotting them here as well drew the "predict" line onto the
# current figure with no plt.show() in between. The next cell then drew it a
# second time, duplicating the line and its legend entry.
Y_pred = model.predict(X_test)
print(Y_pred)
# In[17]:
# Compare predicted and actual Loyal values for each test sample.
# Start a fresh figure so this chart does not draw onto whatever figure
# earlier cells left current (e.g. a stray "predict" line).
plt.figure()
positions = range(len(Y_pred))  # x axis: position of the sample within the test set
plt.plot(positions, Y_pred, 'b', label="predict")
plt.plot(positions, Y_test, 'r', label="test")
plt.legend(loc="upper right")  # show the line labels
plt.xlabel("the number of Loyal")
plt.ylabel('value of Loyal')
# NOTE: despite the file name this is a prediction-vs-actual line chart, not
# an ROC curve; the name is kept so existing references to ROC.jpg still work.
plt.savefig("ROC.jpg")
plt.show()
# In[ ]:
# Example screenshots follow: